eXNN.InnerNeuralViz.hook
import torch.nn as nn


class Hook:
    """Capture the forward output and output gradient of a single module.

    On construction, two hooks are registered on ``m``:

    * a forward hook that stores the module's forward output in ``fwd``;
    * a backward hook that stores ``grad_output[0]`` in ``bwd``.

    Both attributes start as ``None`` and are overwritten each time the
    corresponding hook fires. Call :meth:`remove` to unregister the hooks
    once they are no longer needed. Otherwise they stay attached to the
    module (and keep capturing) for as long as the module lives.
    """

    def __init__(self, m: nn.Module):
        self.module = m
        self.fwd = None  # latest forward output of ``m``, as returned by it
        self.bwd = None  # latest ``grad_output[0]`` seen by the backward hook

        def fwd_hook(module, inputs, output):
            self.fwd = output

        def bwd_hook(module, grad_input, grad_output):
            # grad_output is a tuple; only its first entry is kept.
            self.bwd = grad_output[0]

        # Keep the handles so the hooks can be unregistered via remove().
        self._fwd_handle = self.module.register_forward_hook(fwd_hook)
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch
        # in favour of register_full_backward_hook. The full variant can raise
        # when the module output is later modified in-place (e.g.
        # ReLU(inplace=True)), so switching needs checking against the models
        # this is used with before changing it.
        self._bwd_handle = self.module.register_backward_hook(bwd_hook)

    def clear(self):
        """Reset the captured forward output and gradient to ``None``."""
        self.fwd = None
        self.bwd = None

    def remove(self):
        """Unregister both hooks from the module; safe to call repeatedly.

        Values already captured in ``fwd`` / ``bwd`` are kept; call
        :meth:`clear` to drop them as well.
        """
        for handle in (self._fwd_handle, self._bwd_handle):
            if handle is not None:
                handle.remove()
        self._fwd_handle = None
        self._bwd_handle = None


def _get_module_by_name(model, name):
    """Return the submodule of ``model`` registered under ``name``.

    ``name`` is matched against the names yielded by
    ``model.named_modules()`` (dotted paths such as ``"features.0"``).

    Raises:
        ValueError: if no submodule with that name exists. This is a
            subclass of ``Exception``, so existing broad handlers still
            catch it.
    """
    for module_name, module in model.named_modules():
        if module_name == name:
            return module
    raise ValueError(f'Model does not contain submodule {name}')


def get_hook(model, layer_name) -> Hook:
    """Attach a :class:`Hook` to the submodule of ``model`` named ``layer_name``.

    Raises:
        ValueError: if ``model`` has no submodule called ``layer_name``.
    """
    layer = _get_module_by_name(model, layer_name)
    return Hook(layer)
class
Hook:
class Hook:
    """Capture the forward output and output gradient of a single module.

    On construction, two hooks are registered on ``m``:

    * a forward hook that stores the module's forward output in ``fwd``;
    * a backward hook that stores ``grad_output[0]`` in ``bwd``.

    Both attributes start as ``None`` and are overwritten each time the
    corresponding hook fires. Call :meth:`remove` to unregister the hooks
    once they are no longer needed. Otherwise they stay attached to the
    module (and keep capturing) for as long as the module lives.
    """

    def __init__(self, m: nn.Module):
        self.module = m
        self.fwd = None  # latest forward output of ``m``, as returned by it
        self.bwd = None  # latest ``grad_output[0]`` seen by the backward hook

        def fwd_hook(module, inputs, output):
            self.fwd = output

        def bwd_hook(module, grad_input, grad_output):
            # grad_output is a tuple; only its first entry is kept.
            self.bwd = grad_output[0]

        # Keep the handles so the hooks can be unregistered via remove().
        self._fwd_handle = self.module.register_forward_hook(fwd_hook)
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch
        # in favour of register_full_backward_hook. The full variant can raise
        # when the module output is later modified in-place (e.g.
        # ReLU(inplace=True)), so switching needs checking against the models
        # this is used with before changing it.
        self._bwd_handle = self.module.register_backward_hook(bwd_hook)

    def clear(self):
        """Reset the captured forward output and gradient to ``None``."""
        self.fwd = None
        self.bwd = None

    def remove(self):
        """Unregister both hooks from the module; safe to call repeatedly.

        Values already captured in ``fwd`` / ``bwd`` are kept; call
        :meth:`clear` to drop them as well.
        """
        for handle in (self._fwd_handle, self._bwd_handle):
            if handle is not None:
                handle.remove()
        self._fwd_handle = None
        self._bwd_handle = None