eXNN.InnerNeuralViz.api
```python
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from typing import Dict, List, Optional
from sklearn.decomposition import PCA
import umap
from eXNN.InnerNeuralViz.hook import get_hook


def _plot(embedding, labels):
    # Scatter plot of a 2D embedding, optionally colored by labels
    fig, ax = plt.subplots()
    ax.scatter(x=embedding[:, 0], y=embedding[:, 1], c=labels)
    ax.axis('Off')
    plt.close()
    return fig


def ReduceDim(data: torch.Tensor,
              mode: str) -> np.ndarray:
    """This function reduces data dimensionality to 2 dimensions.

    Args:
        data (torch.Tensor): input data of shape NxC1x...xCk,
            where N is the number of data points,
            C1,...,Ck are dimensions of each data point
        mode (str): dimensionality reduction mode (`umap` or `pca`)

    Raises:
        ValueError: raised if an unsupported mode is provided

    Returns:
        np.ndarray: data projected onto a 2D space, of shape Nx2
    """

    data = data.detach().cpu().numpy().reshape((len(data), -1))
    if mode == 'pca':
        return PCA(n_components=2).fit_transform(data)
    elif mode == 'umap':
        return umap.UMAP().fit_transform(data)
    else:
        raise ValueError(f'Unsupported mode: `{mode}`')


def VisualizeNetSpace(model: torch.nn.Module,
                      mode: str,
                      data: torch.Tensor,
                      layers: Optional[List[str]] = None,
                      labels: Optional[torch.Tensor] = None,
                      chunk_size: Optional[int] = None) \
        -> Dict[str, matplotlib.figure.Figure]:
    """This function visualizes data latent representations on neural network layers.

    Args:
        model (torch.nn.Module): neural network
        mode (str): dimensionality reduction mode (`umap` or `pca`)
        data (torch.Tensor): input data of shape NxC1x...xCk,
            where N is the number of data points,
            C1,...,Ck are dimensions of each data point
        layers (Optional[List[str]], optional): list of layers for visualization.
            Defaults to None. If None, visualization is performed for all layers
        labels (Optional[torch.Tensor], optional): data labels (colors).
            Defaults to None. If None, all points are visualized with the same color
        chunk_size (Optional[int], optional): batch size for data processing.
            Defaults to None. If None, all data is processed in one batch

    Returns:
        Dict[str, matplotlib.figure.Figure]: dictionary with a visualization
            of latent representations for each layer
    """

    if layers is None:
        layers = [_[0] for _ in model.named_children()]
    if labels is not None:
        labels = labels.detach().cpu().numpy()
    # Register forward hooks that capture each requested layer's output
    hooks = {layer: get_hook(model, layer) for layer in layers}
    if chunk_size is None:
        # Single forward pass over the whole dataset
        with torch.no_grad():
            out = model(data)
        visualizations = {'input': _plot(ReduceDim(data, mode), labels)}
        for layer in layers:
            visualizations[layer] = _plot(ReduceDim(hooks[layer].fwd, mode), labels)
        return visualizations
    else:
        # Process the data in chunks and concatenate the captured representations
        representations = {layer: [] for layer in layers}
        for i in range(math.ceil(len(data) / chunk_size)):
            with torch.no_grad():
                out = model(data[i * chunk_size:(i + 1) * chunk_size])
            for layer in layers:
                representations[layer].append(hooks[layer].fwd.detach().cpu())
        visualizations = {'input': _plot(ReduceDim(data, mode), labels)}
        for layer in layers:
            layer_reprs = torch.cat(representations[layer], dim=0)
            visualizations[layer] = _plot(ReduceDim(layer_reprs, mode), labels)
        return visualizations


def get_random_input(dims: List[int]) -> torch.Tensor:
    """This function generates a uniformly distributed random tensor of the given shape.

    Args:
        dims (List[int]): required data shape

    Returns:
        torch.Tensor: uniformly distributed tensor of the given shape
    """
    return torch.rand(size=dims)
```
ReduceDim(data: torch.Tensor, mode: str) -> np.ndarray
This function reduces data dimensionality to 2 dimensions.

Args:
    data (torch.Tensor): input data of shape NxC1x...xCk,
        where N is the number of data points,
        C1,...,Ck are dimensions of each data point
    mode (str): dimensionality reduction mode (`umap` or `pca`)

Raises:
    ValueError: raised if an unsupported mode is provided

Returns:
    np.ndarray: data projected onto a 2D space, of shape Nx2
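A minimal usage sketch (the random 100-point input below is illustrative; any tensor of shape NxC1x...xCk works):

```python
import torch
from eXNN.InnerNeuralViz.api import ReduceDim

data = torch.rand(100, 3, 32, 32)        # 100 synthetic data points
embedding = ReduceDim(data, mode='pca')  # mode='umap' works the same way
print(embedding.shape)                   # (100, 2)
```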
VisualizeNetSpace(model: torch.nn.Module, mode: str, data: torch.Tensor, layers: Optional[List[str]] = None, labels: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None) -> Dict[str, matplotlib.figure.Figure]
This function visualizes data latent representations on neural network layers.

Args:
    model (torch.nn.Module): neural network
    mode (str): dimensionality reduction mode (`umap` or `pca`)
    data (torch.Tensor): input data of shape NxC1x...xCk,
        where N is the number of data points,
        C1,...,Ck are dimensions of each data point
    layers (Optional[List[str]], optional): list of layers for visualization.
        Defaults to None. If None, visualization is performed for all layers
    labels (Optional[torch.Tensor], optional): data labels (colors).
        Defaults to None. If None, all points are visualized with the same color
    chunk_size (Optional[int], optional): batch size for data processing.
        Defaults to None. If None, all data is processed in one batch

Returns:
    Dict[str, matplotlib.figure.Figure]: dictionary with a visualization
        of latent representations for each layer
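A minimal usage sketch; the two-layer model, the layer names `fc1`/`fc2`, and the output file names below are illustrative assumptions, not part of the library:

```python
import torch
import torch.nn as nn
from eXNN.InnerNeuralViz.api import VisualizeNetSpace, get_random_input

# Illustrative model; any torch.nn.Module with named children works
model = nn.Sequential()
model.add_module('fc1', nn.Linear(16, 8))
model.add_module('fc2', nn.Linear(8, 2))

data = get_random_input([256, 16])    # 256 points of dimension 16
labels = torch.randint(0, 2, (256,))  # optional colors for the scatter plots

figures = VisualizeNetSpace(model, mode='pca', data=data,
                            layers=['fc1', 'fc2'],
                            labels=labels,
                            chunk_size=64)  # forward passes in batches of 64

for name, fig in figures.items():     # keys: 'input', 'fc1', 'fc2'
    fig.savefig(f'{name}_projection.png')
```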
get_random_input(dims: List[int]) -> torch.Tensor
This function generates a uniformly distributed random tensor of the given shape.

Args:
    dims (List[int]): required data shape

Returns:
    torch.Tensor: uniformly distributed tensor of the given shape
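For example, to make a batch of image-shaped random inputs (the shape below is illustrative):

```python
from eXNN.InnerNeuralViz.api import get_random_input

batch = get_random_input([32, 3, 28, 28])  # values drawn uniformly from [0, 1)
print(batch.shape)                         # torch.Size([32, 3, 28, 28])
```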