eXNN.InnerNeuralTopology.homologies
"""Persistent-homology barcodes of a network's inner-layer activations.

The output of one layer is captured with a forward hook and treated as a point
cloud (each entry along the first axis is one point). It is passed through a
giotto-tda persistence transformer and drawn as a barcode with matplotlib.
"""
import torch
import matplotlib.pyplot as plt
from typing import List, Dict
from gtda.homology import (VietorisRipsPersistence, WeakAlphaPersistence,
                           SparseRipsPersistence)


def GetActivation(model: torch.nn.Module,
                  x: torch.Tensor,
                  layer: str):
    """Run ``model`` on ``x`` and return the detached output of one submodule.

    A temporary forward hook is registered on the submodule named ``layer``.
    The model is called once, and the hook is removed again, even if the
    forward pass raises.

    Args:
        model: Network to run.
        x: Input passed unchanged to ``model``.
        layer: Attribute name of the submodule whose output is captured.
            Dotted paths such as ``"encoder.fc"`` reach nested submodules.

    Returns:
        The submodule's output, detached from the autograd graph. If the
        submodule is called more than once during the forward pass, the
        output of the last call is returned.

    Raises:
        AttributeError: If ``layer`` does not name a submodule of ``model``.
        KeyError: If the submodule is never called during the forward pass.
    """
    activation = {}

    def getActivation(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()

        return hook

    # Walk the dotted path one attribute at a time; a plain name is one step.
    target = model
    for part in layer.split("."):
        target = getattr(target, part)

    handle = target.register_forward_hook(getActivation(layer))
    try:
        # Call the module (not .forward) so hooks registered on `model`
        # itself still run.
        model(x)
    finally:
        # Never leave the hook attached, even if the forward pass fails.
        handle.remove()

    if layer not in activation:
        raise KeyError(
            f"layer {layer!r} was not called during the forward pass")
    return activation[layer]


def diagram_to_barcode(plot):
    """Extract ``(birth, death)`` pairs per homology dimension from a plot.

    Args:
        plot: Figure returned by a giotto-tda ``.plot`` call. Every trace in
            ``plot['data']`` that has a ``name`` contributes its ``x`` (birth)
            and ``y`` (death) coordinates; unnamed traces are skipped.

    Returns:
        Dict mapping each trace name to its pairs, sorted by birth (pairs
        with equal birth keep their original order).
    """
    return {
        trace['name']: sorted(zip(trace['x'], trace['y']),
                              key=lambda bar: bar[0])
        for trace in plot['data']
        if trace['name'] is not None
    }


def plot_barcode(barcode: Dict):
    """Draw a persistence barcode with one subplot per homology dimension.

    Args:
        barcode: Mapping from a homology name (e.g. ``"H0"``) to a list of
            ``(birth, death)`` pairs, as produced by ``diagram_to_barcode``.

    Returns:
        The matplotlib ``Figure``. It is passed to ``plt.close`` before
        returning, so pyplot no longer manages it. An empty ``barcode``
        yields a single blank figure with the axes hidden.
    """
    names = list(barcode)
    if not names:
        # plt.subplots(0) raises, so return an empty placeholder figure.
        fig, ax = plt.subplots(figsize=(15, 5))
        ax.set_axis_off()
        plt.close(fig)
        return fig

    # squeeze=False always yields a 2-D array of axes, even for one subplot.
    fig, axes = plt.subplots(len(names), squeeze=False,
                             figsize=(15, min(10, 5 * len(names))))
    for ax, name in zip(axes[:, 0], names):
        ax.set_title(name)
        ax.set_ylim([-0.05, 1.05])
        bars = barcode[name]
        n = len(bars)
        if n:
            # Stack the bars evenly in [0, 1): bar j is drawn at height j / n.
            ax.hlines([j / n for j in range(n)],
                      [birth for birth, _ in bars],
                      [death for _, death in bars],
                      colors='black')
        # Heights only separate the bars; their values carry no meaning.
        ax.tick_params(axis='y', labelleft=False)
    plt.close(fig)
    return fig


def _ComputeBarcode(data: torch.Tensor,
                    hom_type: str,
                    coefs_type: str):
    """Compute the H0 persistence of ``data`` and return its barcode figure.

    Args:
        data: Activations. Each entry along the first axis is one point;
            tensors with more than two dimensions are flattened to
            ``(data.shape[0], -1)``.
        hom_type: ``"standard"`` (Vietoris-Rips), ``"sparse"`` (sparse Rips)
            or ``"weak"`` (weak alpha) persistence.
        coefs_type: Converted with ``int`` and passed as ``coeff`` to the
            giotto-tda persistence transformer.

    Returns:
        The matplotlib figure built by ``plot_barcode``.

    Raises:
        ValueError: If ``hom_type`` is not one of the supported names, or
            ``coefs_type`` is not an integer string.
    """
    coeff = int(coefs_type)
    if hom_type == "standard":
        VR = VietorisRipsPersistence(
            homology_dimensions=[0], collapse_edges=True, coeff=coeff)
    elif hom_type == "sparse":
        VR = SparseRipsPersistence(homology_dimensions=[0], coeff=coeff)
    elif hom_type == "weak":
        VR = WeakAlphaPersistence(
            homology_dimensions=[0], collapse_edges=True, coeff=coeff)
    else:
        # Raise rather than `assert`: asserts are stripped under `python -O`.
        raise ValueError(
            "hom_type must be one of: \"standard\", \"sparse\", \"weak\"!")

    if data.dim() > 2:
        data = torch.flatten(data, start_dim=1)
    # Move to host memory as numpy and add a leading axis: the whole batch
    # forms a single point cloud.
    points = data.detach().cpu().numpy()[None]
    diagrams = VR.fit_transform(points)
    barcode = diagram_to_barcode(VR.plot(diagrams))
    return plot_barcode(barcode)


def InnerNetspaceHomologies(model: torch.nn.Module,
                            x: torch.Tensor,
                            layer: str,
                            hom_type: str,
                            coefs_type: str):
    """Plot the H0 barcode of ``layer``'s activations on input ``x``.

    Args:
        model: Network to run.
        x: Input batch.
        layer: Name of the submodule whose output is analysed.
        hom_type: ``"standard"``, ``"sparse"`` or ``"weak"``.
        coefs_type: Coefficient value given as a string, e.g. ``"2"``.

    Returns:
        The matplotlib barcode figure.
    """
    act = GetActivation(model, x, layer)
    return _ComputeBarcode(act, hom_type, coefs_type)


def InnerNetspaceHomologiesExperimental(model: torch.nn.Module,
                                        x: torch.Tensor,
                                        layer: str,
                                        dimensions: List[int] = [0],
                                        make_barplot: bool = True,
                                        rm_empty: bool = True):
    """Vietoris-Rips persistence of ``layer``'s activations in chosen dims.

    Args:
        model: Network to run.
        x: Input batch.
        layer: Name of the submodule whose output is analysed.
        dimensions: Homology dimensions to compute. Negative values count
            back from the number of features ``N`` of the (flattened)
            activation. Values outside ``[0, N)`` are dropped, and so are
            duplicates. The default list is only read, never mutated.
        make_barplot: If True, return a matplotlib barcode; otherwise return
            the object produced by ``VietorisRipsPersistence.plot``.
        rm_empty: With ``make_barplot``, omit dimensions that have no bars.

    Raises:
        ValueError: If no requested dimension lies inside ``[0, N)``.
    """
    act = GetActivation(model, x, layer)
    # Same point-cloud layout as _ComputeBarcode: flatten trailing dims.
    if act.dim() > 2:
        act = torch.flatten(act, start_dim=1)
    points = act.detach().cpu().numpy()[None]

    # Dimensions must not be outside layer dimensionality.
    N = points.shape[-1]
    resolved = (d if d >= 0 else N + d for d in dimensions)
    dims = sorted({d for d in resolved if 0 <= d < N})
    if not dims:
        raise ValueError(
            f"no homology dimension in {dimensions!r} is valid for a layer "
            f"with {N} features")

    VR = VietorisRipsPersistence(homology_dimensions=dims,
                                 collapse_edges=True)
    diagrams = VR.fit_transform(points)
    plot = VR.plot(diagrams)
    if not make_barplot:
        return plot
    barcode = diagram_to_barcode(plot)
    if rm_empty:
        barcode = {key: val for key, val in barcode.items() if val}
    return plot_barcode(barcode)
def
GetActivation(model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str):
def GetActivation(model: torch.nn.Module,
                  x: torch.Tensor,
                  layer: str):
    """Run ``model`` on ``x`` and return the detached output of one submodule.

    A temporary forward hook is registered on the submodule named ``layer``.
    The model is called once, and the hook is removed again, even if the
    forward pass raises.

    Args:
        model: Network to run.
        x: Input passed unchanged to ``model``.
        layer: Attribute name of the submodule whose output is captured.
            Dotted paths such as ``"encoder.fc"`` reach nested submodules.

    Returns:
        The submodule's output, detached from the autograd graph. If the
        submodule is called more than once during the forward pass, the
        output of the last call is returned.

    Raises:
        AttributeError: If ``layer`` does not name a submodule of ``model``.
        KeyError: If the submodule is never called during the forward pass.
    """
    activation = {}

    def getActivation(name):
        def hook(module, inputs, output):
            activation[name] = output.detach()

        return hook

    # Walk the dotted path one attribute at a time; a plain name is one step.
    target = model
    for part in layer.split("."):
        target = getattr(target, part)

    handle = target.register_forward_hook(getActivation(layer))
    try:
        # Call the module (not .forward) so hooks registered on `model`
        # itself still run.
        model(x)
    finally:
        # Never leave the hook attached, even if the forward pass fails.
        handle.remove()

    if layer not in activation:
        raise KeyError(
            f"layer {layer!r} was not called during the forward pass")
    return activation[layer]
def
diagram_to_barcode(plot):
def
plot_barcode(barcode: Dict):
def plot_barcode(barcode: Dict):
    """Draw a persistence barcode with one subplot per homology dimension.

    Args:
        barcode: Mapping from a homology name (e.g. ``"H0"``) to a list of
            ``(birth, death)`` pairs, as produced by ``diagram_to_barcode``.

    Returns:
        The matplotlib ``Figure``. It is passed to ``plt.close`` before
        returning, so pyplot no longer manages it. An empty ``barcode``
        yields a single blank figure with the axes hidden.
    """
    names = list(barcode)
    if not names:
        # plt.subplots(0) raises, so return an empty placeholder figure.
        fig, ax = plt.subplots(figsize=(15, 5))
        ax.set_axis_off()
        plt.close(fig)
        return fig

    # squeeze=False always yields a 2-D array of axes, even for one subplot.
    fig, axes = plt.subplots(len(names), squeeze=False,
                             figsize=(15, min(10, 5 * len(names))))
    for ax, name in zip(axes[:, 0], names):
        ax.set_title(name)
        ax.set_ylim([-0.05, 1.05])
        bars = barcode[name]
        n = len(bars)
        if n:
            # Stack the bars evenly in [0, 1): bar j is drawn at height j / n.
            ax.hlines([j / n for j in range(n)],
                      [birth for birth, _ in bars],
                      [death for _, death in bars],
                      colors='black')
        # Heights only separate the bars; their values carry no meaning.
        ax.tick_params(axis='y', labelleft=False)
    plt.close(fig)
    return fig
def
InnerNetspaceHomologies( model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str, hom_type: str, coefs_type: str):
def
InnerNetspaceHomologiesExperimental( model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str, dimensions: List[int] = [0], make_barplot: bool = True, rm_empty: bool = True):
def InnerNetspaceHomologiesExperimental(model: torch.nn.Module,
                                        x: torch.Tensor,
                                        layer: str,
                                        dimensions: List[int] = [0],
                                        make_barplot: bool = True,
                                        rm_empty: bool = True):
    """Vietoris-Rips persistence of ``layer``'s activations in chosen dims.

    Args:
        model: Network to run.
        x: Input batch.
        layer: Name of the submodule whose output is analysed.
        dimensions: Homology dimensions to compute. Negative values count
            back from the number of features ``N`` of the (flattened)
            activation. Values outside ``[0, N)`` are dropped, and so are
            duplicates. The default list is only read, never mutated.
        make_barplot: If True, return a matplotlib barcode; otherwise return
            the object produced by ``VietorisRipsPersistence.plot``.
        rm_empty: With ``make_barplot``, omit dimensions that have no bars.

    Raises:
        ValueError: If no requested dimension lies inside ``[0, N)``.
    """
    act = GetActivation(model, x, layer)
    # Same point-cloud layout as _ComputeBarcode: flatten trailing dims so
    # each entry along the first axis is one point.
    if act.dim() > 2:
        act = torch.flatten(act, start_dim=1)
    # Move to host memory as numpy and add a leading axis: the whole batch
    # forms a single point cloud.
    points = act.detach().cpu().numpy()[None]

    # Dimensions must not be outside layer dimensionality.
    N = points.shape[-1]
    resolved = (d if d >= 0 else N + d for d in dimensions)
    dims = sorted({d for d in resolved if 0 <= d < N})
    if not dims:
        raise ValueError(
            f"no homology dimension in {dimensions!r} is valid for a layer "
            f"with {N} features")

    VR = VietorisRipsPersistence(homology_dimensions=dims,
                                 collapse_edges=True)
    diagrams = VR.fit_transform(points)
    plot = VR.plot(diagrams)
    if not make_barplot:
        return plot
    barcode = diagram_to_barcode(plot)
    if rm_empty:
        barcode = {key: val for key, val in barcode.items() if val}
    return plot_barcode(barcode)