eXNN.NetBayesianization.wrap

 1from typing import Optional
 2import torch
 3import torch.nn as nn
 4import copy
 5import torch.optim
 6from torch.distributions import Beta
 7
 8# calculate mean and std after applying bayesian
 9
10
11class NetworkBayes(nn.Module):
12    def __init__(self,
13                 model: nn.Module,
14                 dropout_p: float):
15
16        super(NetworkBayes, self).__init__()
17        self.model = model
18        self.dropout_p = dropout_p
19
20    def mean_forward(self,
21                     data: torch.Tensor,
22                     n_iter: int):
23
24        results = []
25        for i in range(n_iter):
26            model_copy = copy.deepcopy(self.model)
27            state_dict = model_copy.state_dict()
28            state_dict_v2 = copy.deepcopy(state_dict)
29            for key, value in state_dict_v2.items():
30                if 'weight' in key:
31                    output = nn.functional.dropout(value, self.dropout_p, training=True)
32                    state_dict_v2[key] = output
33            model_copy.load_state_dict(state_dict_v2, strict=True)
34            output = model_copy(data)
35            results.append(output)
36
37        results = torch.stack(results, dim=1)
38        results = torch.stack([
39            torch.mean(results, dim=1),
40            torch.std(results, dim=1)
41        ], dim=0)
42        return results
43
44
# Calculate mean and std of a model's outputs after applying Bayesian weight
# dropout whose rate is drawn from a Beta distribution for every sample.
class NetworkBayesBeta(nn.Module):
    """Monte-Carlo wrapper with a Beta-distributed weight-dropout rate.

    For every sample drawn by :meth:`mean_forward`, a dropout rate is sampled
    from ``Beta(alpha, beta)`` and applied (via
    :func:`torch.nn.functional.dropout`) to each ``state_dict`` entry of a
    deep copy of ``model`` whose key contains ``'weight'``. The wrapped model
    itself is never modified.

    Args:
        model: Network to wrap.
        alpha: First concentration parameter of the Beta distribution.
        beta: Second concentration parameter of the Beta distribution.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 alpha: float,
                 beta: float):
        super(NetworkBayesBeta, self).__init__()
        self.model = model
        self.alpha = alpha
        self.beta = beta

    def mean_forward(self,
                     data: torch.Tensor,
                     n_iter: int):
        """Return the element-wise mean and std over ``n_iter`` perturbed outputs.

        Args:
            data: Input passed unchanged to every perturbed copy of the model.
            n_iter: Number of Monte-Carlo samples; must be at least 1. With
                ``n_iter == 1`` the (unbiased) standard deviation is NaN.

        Returns:
            Tensor of shape ``(2, *out.shape)``, where ``out`` is a single
            model output (assumed to have at least one dimension): index 0
            holds the mean, index 1 the standard deviation.

        Raises:
            ValueError: If ``n_iter`` is less than 1.
        """
        if n_iter < 1:
            raise ValueError(f"n_iter must be at least 1, got {n_iter}")

        # float() keeps integer concentrations working: torch.tensor(2) would
        # be an integer tensor, while Beta expects floating-point parameters.
        rate_dist = Beta(torch.tensor(float(self.alpha)),
                         torch.tensor(float(self.beta)))

        samples = []
        for _ in range(n_iter):
            # One rate per sample, shared by every weight tensor of this copy.
            rate = float(rate_dist.sample())
            # Perturb a private copy so the wrapped model keeps its weights.
            model_copy = copy.deepcopy(self.model)
            # The copy's state_dict can be edited directly: dropout is
            # out-of-place, so no second deepcopy is needed.
            state = model_copy.state_dict()
            for key, value in state.items():
                if 'weight' in key:
                    state[key] = nn.functional.dropout(value, rate, training=True)
            model_copy.load_state_dict(state, strict=True)
            samples.append(model_copy(data))

        # The sample axis is inserted at dim 1 and reduced over below.
        stacked = torch.stack(samples, dim=1)
        return torch.stack([
            torch.mean(stacked, dim=1),
            torch.std(stacked, dim=1),
        ], dim=0)
81
82
def create_bayesian_wrapper(model: torch.nn.Module,
                            mode: Optional[str] = 'basic',
                            p: Optional[float] = None,
                            a: Optional[float] = None,
                            b: Optional[float] = None) -> torch.nn.Module:
    """Wrap ``model`` for Monte-Carlo weight-dropout uncertainty estimation.

    Args:
        model: Network to wrap.
        mode: ``'basic'`` for a fixed dropout rate (``NetworkBayes``) or
            ``'beta'`` for a Beta-distributed rate (``NetworkBayesBeta``).
        p: Dropout probability; only used when ``mode == 'basic'``.
        a: First Beta concentration; only used when ``mode == 'beta'``.
        b: Second Beta concentration; only used when ``mode == 'beta'``.

    Returns:
        The wrapper module; call its ``mean_forward`` method to sample.

    Raises:
        ValueError: If ``mode`` is neither ``'basic'`` nor ``'beta'``.
    """
    if mode == 'basic':
        return NetworkBayes(model, p)
    if mode == 'beta':
        return NetworkBayesBeta(model, a, b)
    # Reject unknown modes explicitly instead of failing on an unbound local.
    raise ValueError(f"Unknown mode {mode!r}; expected 'basic' or 'beta'")
class NetworkBayes(nn.Module):
    """Monte-Carlo wrapper that perturbs a model's weights with dropout.

    Every sample drawn by :meth:`mean_forward` runs a deep copy of ``model``
    in which each ``state_dict`` entry whose key contains ``'weight'`` has
    been passed through :func:`torch.nn.functional.dropout` with probability
    ``dropout_p``. The wrapped model itself is never modified.

    Args:
        model: Network to wrap.
        dropout_p: Dropout probability applied to the weight tensors.
    """

    def __init__(self,
                 model: nn.Module,
                 dropout_p: float):
        super(NetworkBayes, self).__init__()
        self.model = model
        self.dropout_p = dropout_p

    def mean_forward(self,
                     data: torch.Tensor,
                     n_iter: int):
        """Return the element-wise mean and std over ``n_iter`` perturbed outputs.

        Args:
            data: Input passed unchanged to every perturbed copy of the model.
            n_iter: Number of Monte-Carlo samples; must be at least 1. With
                ``n_iter == 1`` the (unbiased) standard deviation is NaN.

        Returns:
            Tensor of shape ``(2, *out.shape)``, where ``out`` is a single
            model output (assumed to have at least one dimension): index 0
            holds the mean, index 1 the standard deviation.

        Raises:
            ValueError: If ``n_iter`` is less than 1.
        """
        if n_iter < 1:
            raise ValueError(f"n_iter must be at least 1, got {n_iter}")

        samples = []
        for _ in range(n_iter):
            # Perturb a private copy so the wrapped model keeps its weights.
            model_copy = copy.deepcopy(self.model)
            # state_dict() returns a fresh dict over the copy's tensors and
            # dropout is out-of-place, so the dict can be edited directly;
            # a second deepcopy would only duplicate every weight again.
            state = model_copy.state_dict()
            for key, value in state.items():
                if 'weight' in key:
                    state[key] = nn.functional.dropout(
                        value, self.dropout_p, training=True)
            model_copy.load_state_dict(state, strict=True)
            samples.append(model_copy(data))

        # The sample axis is inserted at dim 1 and reduced over below.
        stacked = torch.stack(samples, dim=1)
        return torch.stack([
            torch.mean(stacked, dim=1),
            torch.std(stacked, dim=1),
        ], dim=0)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # Minimal example: conv1 and conv2 are assigned as attributes, so they
    # are registered as submodules of Model.
    def __init__(self):
        # The parent __init__() must run before submodules are assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # conv1 -> ReLU -> conv2 -> ReLU
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As the example above shows, the parent class's __init__() must be called before assigning attributes on the child.

training (bool): whether this module is in training or evaluation mode.

def __init__(self,
             model: nn.Module,
             dropout_p: float):
    """Register ``model`` as a submodule and remember ``dropout_p``."""
    # nn.Module's own initialiser must run before any attribute is assigned.
    super(NetworkBayes, self).__init__()
    # Probability later handed to dropout for every weight tensor.
    self.dropout_p = dropout_p
    # Assigning a Module attribute registers it as a child of this wrapper.
    self.model = model

Calls nn.Module.__init__() (which initializes internal Module state, shared by both nn.Module and ScriptModule), then stores model and dropout_p as attributes.

def mean_forward(self,
                 data: torch.Tensor,
                 n_iter: int):
    """Return the element-wise mean and std over ``n_iter`` perturbed outputs.

    Each sample runs a deep copy of ``self.model`` whose ``state_dict``
    entries with ``'weight'`` in the key went through dropout with
    probability ``self.dropout_p``; ``self.model`` itself is not modified.

    Args:
        data: Input passed unchanged to every perturbed copy of the model.
        n_iter: Number of Monte-Carlo samples; must be at least 1. With
            ``n_iter == 1`` the (unbiased) standard deviation is NaN.

    Returns:
        Tensor of shape ``(2, *out.shape)``, where ``out`` is a single model
        output (assumed to have at least one dimension): index 0 holds the
        mean, index 1 the standard deviation.

    Raises:
        ValueError: If ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    samples = []
    for _ in range(n_iter):
        # Perturb a private copy so the wrapped model keeps its weights.
        model_copy = copy.deepcopy(self.model)
        # state_dict() returns a fresh dict over the copy's tensors and
        # dropout is out-of-place, so the dict can be edited directly;
        # a second deepcopy would only duplicate every weight again.
        state = model_copy.state_dict()
        for key, value in state.items():
            if 'weight' in key:
                state[key] = nn.functional.dropout(
                    value, self.dropout_p, training=True)
        model_copy.load_state_dict(state, strict=True)
        samples.append(model_copy(data))

    # The sample axis is inserted at dim 1 and reduced over below.
    stacked = torch.stack(samples, dim=1)
    return torch.stack([
        torch.mean(stacked, dim=1),
        torch.std(stacked, dim=1),
    ], dim=0)
Inherited Members
torch.nn.modules.module.Module
dump_patches
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
class NetworkBayesBeta(nn.Module):
    """Monte-Carlo wrapper with a Beta-distributed weight-dropout rate.

    For every sample drawn by :meth:`mean_forward`, a dropout rate is sampled
    from ``Beta(alpha, beta)`` and applied (via
    :func:`torch.nn.functional.dropout`) to each ``state_dict`` entry of a
    deep copy of ``model`` whose key contains ``'weight'``. The wrapped model
    itself is never modified.

    Args:
        model: Network to wrap.
        alpha: First concentration parameter of the Beta distribution.
        beta: Second concentration parameter of the Beta distribution.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 alpha: float,
                 beta: float):
        super(NetworkBayesBeta, self).__init__()
        self.model = model
        self.alpha = alpha
        self.beta = beta

    def mean_forward(self,
                     data: torch.Tensor,
                     n_iter: int):
        """Return the element-wise mean and std over ``n_iter`` perturbed outputs.

        Args:
            data: Input passed unchanged to every perturbed copy of the model.
            n_iter: Number of Monte-Carlo samples; must be at least 1. With
                ``n_iter == 1`` the (unbiased) standard deviation is NaN.

        Returns:
            Tensor of shape ``(2, *out.shape)``, where ``out`` is a single
            model output (assumed to have at least one dimension): index 0
            holds the mean, index 1 the standard deviation.

        Raises:
            ValueError: If ``n_iter`` is less than 1.
        """
        if n_iter < 1:
            raise ValueError(f"n_iter must be at least 1, got {n_iter}")

        # float() keeps integer concentrations working: torch.tensor(2) would
        # be an integer tensor, while Beta expects floating-point parameters.
        rate_dist = Beta(torch.tensor(float(self.alpha)),
                         torch.tensor(float(self.beta)))

        samples = []
        for _ in range(n_iter):
            # One rate per sample, shared by every weight tensor of this copy.
            rate = float(rate_dist.sample())
            # Perturb a private copy so the wrapped model keeps its weights.
            model_copy = copy.deepcopy(self.model)
            # The copy's state_dict can be edited directly: dropout is
            # out-of-place, so no second deepcopy is needed.
            state = model_copy.state_dict()
            for key, value in state.items():
                if 'weight' in key:
                    state[key] = nn.functional.dropout(value, rate, training=True)
            model_copy.load_state_dict(state, strict=True)
            samples.append(model_copy(data))

        # The sample axis is inserted at dim 1 and reduced over below.
        stacked = torch.stack(samples, dim=1)
        return torch.stack([
            torch.mean(stacked, dim=1),
            torch.std(stacked, dim=1),
        ], dim=0)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # Minimal example: conv1 and conv2 are assigned as attributes, so they
    # are registered as submodules of Model.
    def __init__(self):
        # The parent __init__() must run before submodules are assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # conv1 -> ReLU -> conv2 -> ReLU
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As the example above shows, the parent class's __init__() must be called before assigning attributes on the child.

training (bool): whether this module is in training or evaluation mode.

def __init__(self,
             model: torch.nn.Module,
             alpha: float,
             beta: float):
    """Register ``model`` as a submodule and remember the Beta parameters."""
    # nn.Module's own initialiser must run before any attribute is assigned.
    super(NetworkBayesBeta, self).__init__()
    # Concentration parameters of the Beta distribution over dropout rates.
    self.alpha, self.beta = alpha, beta
    # Assigning a Module attribute registers it as a child of this wrapper.
    self.model = model

Calls nn.Module.__init__() (which initializes internal Module state, shared by both nn.Module and ScriptModule), then stores model, alpha and beta as attributes.

def mean_forward(self,
                 data: torch.Tensor,
                 n_iter: int):
    """Return the element-wise mean and std over ``n_iter`` perturbed outputs.

    For every sample a dropout rate is drawn from
    ``Beta(self.alpha, self.beta)`` and applied to each ``state_dict`` entry
    with ``'weight'`` in the key of a deep copy of ``self.model``;
    ``self.model`` itself is not modified.

    Args:
        data: Input passed unchanged to every perturbed copy of the model.
        n_iter: Number of Monte-Carlo samples; must be at least 1. With
            ``n_iter == 1`` the (unbiased) standard deviation is NaN.

    Returns:
        Tensor of shape ``(2, *out.shape)``, where ``out`` is a single model
        output (assumed to have at least one dimension): index 0 holds the
        mean, index 1 the standard deviation.

    Raises:
        ValueError: If ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    # float() keeps integer concentrations working: torch.tensor(2) would be
    # an integer tensor, while Beta expects floating-point parameters.
    rate_dist = Beta(torch.tensor(float(self.alpha)),
                     torch.tensor(float(self.beta)))

    samples = []
    for _ in range(n_iter):
        # One rate per sample, shared by every weight tensor of this copy.
        rate = float(rate_dist.sample())
        # Perturb a private copy so the wrapped model keeps its weights.
        model_copy = copy.deepcopy(self.model)
        # The copy's state_dict can be edited directly: dropout is
        # out-of-place, so no second deepcopy is needed.
        state = model_copy.state_dict()
        for key, value in state.items():
            if 'weight' in key:
                state[key] = nn.functional.dropout(value, rate, training=True)
        model_copy.load_state_dict(state, strict=True)
        samples.append(model_copy(data))

    # The sample axis is inserted at dim 1 and reduced over below.
    stacked = torch.stack(samples, dim=1)
    return torch.stack([
        torch.mean(stacked, dim=1),
        torch.std(stacked, dim=1),
    ], dim=0)
Inherited Members
torch.nn.modules.module.Module
dump_patches
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
def create_bayesian_wrapper(model: torch.nn.Module,
                            mode: Optional[str] = 'basic',
                            p: Optional[float] = None,
                            a: Optional[float] = None,
                            b: Optional[float] = None) -> torch.nn.Module:
    """Wrap ``model`` for Monte-Carlo weight-dropout uncertainty estimation.

    Args:
        model: Network to wrap.
        mode: ``'basic'`` for a fixed dropout rate (``NetworkBayes``) or
            ``'beta'`` for a Beta-distributed rate (``NetworkBayesBeta``).
        p: Dropout probability; only used when ``mode == 'basic'``.
        a: First Beta concentration; only used when ``mode == 'beta'``.
        b: Second Beta concentration; only used when ``mode == 'beta'``.

    Returns:
        The wrapper module; call its ``mean_forward`` method to sample.

    Raises:
        ValueError: If ``mode`` is neither ``'basic'`` nor ``'beta'``.
    """
    if mode == 'basic':
        return NetworkBayes(model, p)
    if mode == 'beta':
        return NetworkBayesBeta(model, a, b)
    # Reject unknown modes explicitly instead of failing on an unbound local.
    raise ValueError(f"Unknown mode {mode!r}; expected 'basic' or 'beta'")