eXNN.InnerNeuralViz.hook

 1import torch.nn as nn
 2
 3
 4class Hook:
 5    def __init__(self, m: nn.Module):
 6        self.module = m
 7        self.fwd = None
 8        self.bwd = None
 9
10        def fwd_hook(m, i, o):
11            self.fwd = o
12
13        def bwd_hook(m, i, o):
14            self.bwd = o[0]
15
16        self.module.register_forward_hook(fwd_hook)
17        self.module.register_backward_hook(bwd_hook)
18
19    def clear(self):
20        self.fwd = None
21        self.bwd = None
22
23
24def _get_module_by_name(model, name):
25    for n, m in model.named_modules():
26        if n == name:
27            return m
28    raise Exception(f'Model does not contain submodule {name}')
29
30
def get_hook(model, layer_name) -> Hook:
    """Attach a :class:`Hook` to the submodule of *model* named *layer_name*.

    The lookup goes through ``_get_module_by_name``, which raises when
    *layer_name* does not name a submodule of *model*.
    """
    return Hook(_get_module_by_name(model, layer_name))
class Hook:
 5class Hook:
 6    def __init__(self, m: nn.Module):
 7        self.module = m
 8        self.fwd = None
 9        self.bwd = None
10
11        def fwd_hook(m, i, o):
12            self.fwd = o
13
14        def bwd_hook(m, i, o):
15            self.bwd = o[0]
16
17        self.module.register_forward_hook(fwd_hook)
18        self.module.register_backward_hook(bwd_hook)
19
20    def clear(self):
21        self.fwd = None
22        self.bwd = None
Hook(m: torch.nn.modules.module.Module)
 6    def __init__(self, m: nn.Module):
 7        self.module = m
 8        self.fwd = None
 9        self.bwd = None
10
11        def fwd_hook(m, i, o):
12            self.fwd = o
13
14        def bwd_hook(m, i, o):
15            self.bwd = o[0]
16
17        self.module.register_forward_hook(fwd_hook)
18        self.module.register_backward_hook(bwd_hook)
def clear(self):
20    def clear(self):
21        self.fwd = None
22        self.bwd = None
def get_hook(model, layer_name) -> eXNN.InnerNeuralViz.hook.Hook:
def get_hook(model, layer_name) -> Hook:
    """Attach a :class:`Hook` to the submodule of *model* named *layer_name*.

    The lookup goes through ``_get_module_by_name``, which raises when
    *layer_name* does not name a submodule of *model*.
    """
    return Hook(_get_module_by_name(model, layer_name))