eXNN.InnerNeuralTopology.homologies

  1import torch
  2import matplotlib.pyplot as plt
  3from typing import List, Dict
  4from gtda.homology import VietorisRipsPersistence, WeakAlphaPersistence,\
  5    SparseRipsPersistence
  6
  7
  8def GetActivation(model: torch.nn.Module,
  9                  x: torch.Tensor,
 10                  layer: str):
 11    activation = {}
 12
 13    def getActivation(name):
 14        def hook(model, input, output):
 15            activation[name] = output.detach()
 16
 17        return hook
 18
 19    h1 = getattr(model, layer).register_forward_hook(getActivation(layer))
 20    model.forward(x)
 21    h1.remove()
 22    return activation[layer]
 23
 24
 25def diagram_to_barcode(plot):
 26    data = plot['data']
 27    homologies = {}
 28    for h in data:
 29        if h['name'] is None:
 30            continue
 31        homologies[h['name']] = list(zip(h['x'], h['y']))
 32
 33    for h in homologies.keys():
 34        homologies[h] = sorted(homologies[h], key=lambda x: x[0])
 35    return homologies
 36
 37
 38def plot_barcode(barcode: Dict):
 39    homologies = list(barcode.keys())
 40    nplots = len(homologies)
 41    fig, ax = plt.subplots(nplots, figsize=(15, min(10, 5 * nplots)))
 42
 43    if nplots == 1:
 44        ax = [ax]
 45
 46    for i in range(nplots):
 47        name = homologies[i]
 48        ax[i].set_title(name)
 49        ax[i].set_ylim([-0.05, 1.05])
 50        bars = barcode[name]
 51        n = len(bars)
 52        for j in range(n):
 53            bar = bars[j]
 54            ax[i].plot([bar[0], bar[1]], [j / n, j / n], color='black')
 55        labels = ["" for _ in range(len(ax[i].get_yticklabels()))]
 56        ax[i].set_yticks(ax[i].get_yticks())
 57        ax[i].set_yticklabels(labels)
 58
 59    if nplots == 1:
 60        ax = ax[0]
 61    plt.close(fig)
 62    return fig
 63
 64
 65def _ComputeBarcode(data: torch.Tensor,
 66                    hom_type: str,
 67                    coefs_type: str):
 68    if hom_type == "standard":
 69        VR = VietorisRipsPersistence(
 70            homology_dimensions=[0], collapse_edges=True, coeff=int(coefs_type))
 71    elif hom_type == "sparse":
 72        VR = SparseRipsPersistence(homology_dimensions=[0], coeff=int(coefs_type))
 73    elif hom_type == "weak":
 74        VR = WeakAlphaPersistence(
 75            homology_dimensions=[0], collapse_edges=True, coeff=int(coefs_type))
 76    else:
 77        assert False, "hom_type must be one of: \"standard\", \"sparse\", \"weak\"!"
 78
 79    if len(data.shape) > 2:
 80        data = torch.nn.Flatten()(data)
 81    data = data.reshape(1, *data.shape)
 82    diagrams = VR.fit_transform(data)
 83    plot = VR.plot(diagrams)
 84    barcode = diagram_to_barcode(plot)
 85    return plot_barcode(barcode)
 86
 87
 88def InnerNetspaceHomologies(model: torch.nn.Module,
 89                            x: torch.Tensor,
 90                            layer: str,
 91                            hom_type: str,
 92                            coefs_type: str):
 93    act = GetActivation(model, x, layer)
 94    plot = _ComputeBarcode(act, hom_type, coefs_type)
 95    return plot
 96
 97
 98def InnerNetspaceHomologiesExperimental(model: torch.nn.Module,
 99                                        x: torch.Tensor,
100                                        layer: str,
101                                        dimensions: List[int] = [0],
102                                        make_barplot: bool = True,
103                                        rm_empty: bool = True):
104    act = GetActivation(model, x, layer)
105    act = act.reshape(1, *act.shape)
106    # Dimensions must not be outside layer dimensionality
107    N = act.shape[-1]
108    dimensions = [i if i >= 0 else N + i for i in dimensions]
109    dimensions = [i for i in dimensions if ((i >= 0) and (i < N))]
110    dimensions = list(set(dimensions))
111    VR = VietorisRipsPersistence(homology_dimensions=dimensions,
112                                 collapse_edges=True)
113    diagrams = VR.fit_transform(act)
114    plot = VR.plot(diagrams)
115    if make_barplot:
116        barcode = diagram_to_barcode(plot)
117        if rm_empty:
118            barcode = {key: val for key, val in barcode.items() if len(val) > 0}
119        return plot_barcode(barcode)
120    else:
121        return plot
def GetActivation(model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str):
 9def GetActivation(model: torch.nn.Module,
10                  x: torch.Tensor,
11                  layer: str):
12    activation = {}
13
14    def getActivation(name):
15        def hook(model, input, output):
16            activation[name] = output.detach()
17
18        return hook
19
20    h1 = getattr(model, layer).register_forward_hook(getActivation(layer))
21    model.forward(x)
22    h1.remove()
23    return activation[layer]
def diagram_to_barcode(plot):
def diagram_to_barcode(plot):
    """Convert the traces of a plotted diagram into a barcode mapping.

    Every entry of ``plot['data']`` whose ``'name'`` is not ``None`` becomes
    one key of the result. Its value is the list of ``(x, y)`` pairs built
    from the entry's ``'x'`` and ``'y'`` sequences, ordered by ``x``
    (entries with equal ``x`` keep their original relative order).

    Args:
        plot: mapping-like object with a ``'data'`` sequence of entries,
            each exposing ``'name'``, ``'x'`` and ``'y'``.

    Returns:
        Dict mapping each trace name to its sorted list of ``(x, y)`` tuples.
    """
    barcode = {}
    for trace in plot['data']:
        label = trace['name']
        if label is not None:
            barcode[label] = list(zip(trace['x'], trace['y']))

    # Sort only once every trace has been collected; list.sort is stable.
    for pairs in barcode.values():
        pairs.sort(key=lambda pair: pair[0])
    return barcode
def plot_barcode(barcode: Dict):
def plot_barcode(barcode: Dict):
    """Draw a barcode as horizontal segments, one subplot per key.

    Args:
        barcode: mapping from a group name to a sequence of bars, where each
            bar is indexable and ``bar[0]`` / ``bar[1]`` are the x-coordinates
            of its two ends.

    Returns:
        The matplotlib figure holding one subplot per key of ``barcode``.
        The figure is closed with ``plt.close`` before being returned
        (presumably so pyplot does not also display it on its own). For an
        empty ``barcode`` an empty figure is returned.
    """
    names = list(barcode)
    nplots = len(names)

    if nplots == 0:
        # plt.subplots refuses a grid with zero rows; return a blank figure
        # instead of failing inside matplotlib.
        fig = plt.figure(figsize=(15, 5))
        plt.close(fig)
        return fig

    # squeeze=False always yields a 2-D array of axes, so a single subplot
    # needs no special casing.
    fig, axes = plt.subplots(nplots,
                             figsize=(15, min(10, 5 * nplots)),
                             squeeze=False)

    for axis, name in zip(axes[:, 0], names):
        axis.set_title(name)
        axis.set_ylim([-0.05, 1.05])
        bars = barcode[name]
        n = len(bars)
        for j, bar in enumerate(bars):
            # Bars are stacked at evenly spaced heights in [0, 1).
            height = j / n
            axis.plot([bar[0], bar[1]], [height, height], color='black')
        # Keep the tick positions but blank out their labels.
        labels = [""] * len(axis.get_yticklabels())
        axis.set_yticks(axis.get_yticks())
        axis.set_yticklabels(labels)

    plt.close(fig)
    return fig
def InnerNetspaceHomologies( model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str, hom_type: str, coefs_type: str):
def InnerNetspaceHomologies(model: torch.nn.Module,
                            x: torch.Tensor,
                            layer: str,
                            hom_type: str,
                            coefs_type: str):
    """Plot the barcode of the activations one layer produces for ``x``.

    Args:
        model: network to run.
        x: input fed to the model.
        layer: name of the sub-module whose output is analysed.
        hom_type: persistence variant, forwarded to ``_ComputeBarcode``.
        coefs_type: coefficient setting, forwarded to ``_ComputeBarcode``.

    Returns:
        Whatever ``_ComputeBarcode`` returns for the captured activations.
    """
    return _ComputeBarcode(GetActivation(model, x, layer),
                           hom_type,
                           coefs_type)
def InnerNetspaceHomologiesExperimental( model: torch.nn.modules.module.Module, x: torch.Tensor, layer: str, dimensions: List[int] = [0], make_barplot: bool = True, rm_empty: bool = True):
 99def InnerNetspaceHomologiesExperimental(model: torch.nn.Module,
100                                        x: torch.Tensor,
101                                        layer: str,
102                                        dimensions: List[int] = [0],
103                                        make_barplot: bool = True,
104                                        rm_empty: bool = True):
105    act = GetActivation(model, x, layer)
106    act = act.reshape(1, *act.shape)
107    # Dimensions must not be outside layer dimensionality
108    N = act.shape[-1]
109    dimensions = [i if i >= 0 else N + i for i in dimensions]
110    dimensions = [i for i in dimensions if ((i >= 0) and (i < N))]
111    dimensions = list(set(dimensions))
112    VR = VietorisRipsPersistence(homology_dimensions=dimensions,
113                                 collapse_edges=True)
114    diagrams = VR.fit_transform(act)
115    plot = VR.plot(diagrams)
116    if make_barplot:
117        barcode = diagram_to_barcode(plot)
118        if rm_empty:
119            barcode = {key: val for key, val in barcode.items() if len(val) > 0}
120        return plot_barcode(barcode)
121    else:
122        return plot