-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #104 from synsense/dev/nir
Dev/nir
- Loading branch information
Showing
6 changed files
with
324 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
coverage: | ||
status: | ||
project: | ||
default: | ||
# basic | ||
target: 60% | ||
threshold: 0% | ||
patch: | ||
default: | ||
# basic | ||
target: 60% | ||
threshold: 0% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
pbr | ||
numpy | ||
torch>=1.8 | ||
nir | ||
nirtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
from functools import partial | ||
from typing import Optional, Tuple, Union | ||
|
||
import nir | ||
import nirtorch | ||
import numpy as np | ||
import torch | ||
from torch import nn | ||
|
||
import sinabs.layers as sl | ||
|
||
|
||
def _as_pair(x) -> Tuple[int, int]: | ||
try: | ||
if len(x) == 1: | ||
return (x[0], x[0]) | ||
elif len(x) >= 2: | ||
return tuple(x) | ||
else: | ||
raise ValueError() | ||
except TypeError: | ||
return x, x | ||
|
||
|
||
def _import_sinabs_module( | ||
node: nir.NIRNode, batch_size: int, num_timesteps: int | ||
) -> torch.nn.Module: | ||
if isinstance(node, nir.Affine): | ||
linear = nn.Linear( | ||
in_features=node.weight.shape[1], | ||
out_features=node.weight.shape[0], | ||
bias=True, | ||
) | ||
linear.weight.data = torch.tensor(node.weight).float() | ||
linear.bias.data = torch.tensor(node.bias).float() | ||
return linear | ||
|
||
elif isinstance(node, nir.Conv1d): | ||
conv = nn.Conv1d( | ||
in_channels=node.weight.shape[1], | ||
out_channels=node.weight.shape[0], | ||
kernel_size=node.weight.shape[2:], | ||
stride=node.stride, | ||
padding=node.padding, | ||
dilation=node.dilation, | ||
groups=node.groups, | ||
bias=True, | ||
) | ||
conv.weight.data = torch.tensor(node.weight).float() | ||
conv.bias.data = torch.tensor(node.bias).float() | ||
return conv | ||
|
||
elif isinstance(node, nir.Conv2d): | ||
conv = nn.Conv2d( | ||
in_channels=node.weight.shape[1], | ||
out_channels=node.weight.shape[0], | ||
kernel_size=node.weight.shape[2:], | ||
stride=node.stride, | ||
padding=node.padding, | ||
dilation=node.dilation, | ||
groups=node.groups, | ||
bias=True, | ||
) | ||
conv.weight.data = torch.tensor(node.weight).float() | ||
conv.bias.data = torch.tensor(node.bias).float() | ||
return conv | ||
|
||
elif isinstance(node, nir.LI): | ||
if node.v_leak.shape == torch.Size([]): | ||
node.v_leak = node.v_leak.unsqueeze(0) | ||
if node.r.shape == torch.Size([]): | ||
node.r = node.r.unsqueeze(0) | ||
if any(node.v_leak != 0): | ||
raise ValueError("`v_leak` must be 0") | ||
if any(node.r != 1): | ||
raise ValueError("`r` must be 1") | ||
# TODO check for norm_input | ||
return sl.ExpLeakSqueeze( | ||
tau_mem=node.tau, | ||
min_v_mem=None, | ||
num_timesteps=num_timesteps, | ||
batch_size=batch_size, | ||
norm_input=False, | ||
) | ||
|
||
elif isinstance(node, nir.IF): | ||
return sl.IAFSqueeze( | ||
min_v_mem=-node.v_threshold, | ||
num_timesteps=num_timesteps, | ||
batch_size=batch_size, | ||
spike_threshold=node.v_threshold, | ||
) | ||
|
||
elif isinstance(node, nir.LIF): | ||
if node.v_leak.shape == torch.Size([]): | ||
node.v_leak = node.v_leak.unsqueeze(0) | ||
if any(node.v_leak) != 0: | ||
raise ValueError("`v_leak` must be 0") | ||
# TODO check for norm_input | ||
return sl.LIFSqueeze( | ||
tau_mem=node.tau, | ||
min_v_mem=None, | ||
num_timesteps=num_timesteps, | ||
batch_size=batch_size, | ||
spike_threshold=node.v_threshold, | ||
tau_syn=None, | ||
norm_input=False, | ||
) | ||
elif isinstance(node, nir.SumPool2d): | ||
return sl.SumPool2d( | ||
kernel_size=tuple(node.kernel_size), stride=tuple(node.stride) | ||
) | ||
elif isinstance(node, nir.Flatten): | ||
start_dim = node.start_dim + 1 if node.start_dim >= 0 else node.start_dim | ||
end_dim = node.end_dim + 1 if node.end_dim >= 0 else node.end_dim | ||
return nn.Flatten(start_dim=start_dim, end_dim=end_dim) | ||
elif isinstance(node, nir.Input): | ||
return nn.Identity() | ||
elif isinstance(node, nir.Output): | ||
return nn.Identity() | ||
|
||
|
||
def from_nir( | ||
node: nir.NIRNode, batch_size: int = None, num_timesteps: int = None | ||
) -> torch.nn.Module: | ||
return nirtorch.load( | ||
node, | ||
partial( | ||
_import_sinabs_module, batch_size=batch_size, num_timesteps=num_timesteps | ||
), | ||
) | ||
|
||
|
||
def _extend_to_shape(x: Union[torch.Tensor, float], shape: Tuple) -> torch.Tensor: | ||
if x.shape == shape: | ||
return x | ||
elif x.shape == (1,) or x.dim() == 0: | ||
return torch.ones(*shape) * x | ||
else: | ||
raise ValueError(f"Not sure how to extend {x} to shape {shape}") | ||
|
||
|
||
def _extract_sinabs_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: | ||
if type(module) in [sl.IAF, sl.IAFSqueeze]: | ||
layer_shape = module.v_mem.shape[1:] | ||
nir_node = nir.IF( | ||
r=torch.ones(*layer_shape), # Discard batch dim | ||
v_threshold=_extend_to_shape(module.spike_threshold.detach(), layer_shape), | ||
) | ||
return nir_node | ||
elif type(module) in [sl.LIF, sl.LIFSqueeze]: | ||
layer_shape = module.v_mem.shape[0] | ||
return nir.LIF( | ||
tau=module.tau_mem.detach(), | ||
v_threshold=module.spike_threshold.detach(), | ||
v_leak=torch.zeros_like(module.tau_mem.detach()), | ||
r=torch.ones_like(module.tau_mem.detach()), | ||
) | ||
elif type(module) in [sl.ExpLeak, sl.ExpLeakSqueeze]: | ||
return nir.LI( | ||
tau=module.tau_mem.detach(), | ||
v_leak=torch.zeros_like(module.tau_mem.detach()), | ||
r=torch.ones_like(module.tau_mem.detach()), | ||
) | ||
elif isinstance(module, torch.nn.Linear): | ||
if module.bias is None: # Add zero bias if none is present | ||
return nir.Affine( | ||
module.weight.detach(), torch.zeros(*module.weight.shape[:-1]) | ||
) | ||
else: | ||
return nir.Affine(module.weight.detach(), module.bias.detach()) | ||
elif isinstance(module, torch.nn.Conv1d): | ||
return nir.Conv1d( | ||
weight=module.weight.detach(), | ||
stride=module.stride, | ||
padding=module.padding, | ||
dilation=module.dilation, | ||
groups=module.groups, | ||
bias=module.bias.detach() | ||
if module.bias | ||
else torch.zeros((module.weight.shape[0])), | ||
) | ||
elif isinstance(module, torch.nn.Conv2d): | ||
return nir.Conv2d( | ||
input_shape=None, | ||
weight=module.weight.detach(), | ||
stride=module.stride, | ||
padding=module.padding, | ||
dilation=module.dilation, | ||
groups=module.groups, | ||
bias=module.bias.detach() | ||
if isinstance(module.bias, torch.Tensor) | ||
else torch.zeros((module.weight.shape[0])), | ||
) | ||
elif isinstance(module, sl.SumPool2d): | ||
return nir.SumPool2d( | ||
kernel_size=_as_pair(module.kernel_size), # (Height, Width) | ||
stride=_as_pair( | ||
module.kernel_size if module.stride is None else module.stride | ||
), # (Height, width) | ||
padding=(0, 0), # (Height, width) | ||
) | ||
elif isinstance(module, nn.Flatten): | ||
# Getting rid of the batch dimension for NIR | ||
start_dim = module.start_dim - 1 if module.start_dim > 0 else module.start_dim | ||
end_dim = module.end_dim - 1 if module.end_dim > 0 else module.end_dim | ||
return nir.Flatten( | ||
input_type=None, | ||
start_dim=start_dim, | ||
end_dim=end_dim, | ||
) | ||
raise NotImplementedError(f"Module {type(module)} not supported") | ||
|
||
|
||
def to_nir( | ||
module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "model" | ||
) -> nir.NIRNode: | ||
return nirtorch.extract_nir_graph( | ||
module, | ||
_extract_sinabs_module, | ||
sample_data, | ||
model_name=model_name, | ||
ignore_dims=[0], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import nir | ||
import torch | ||
import torch.nn as nn | ||
|
||
import sinabs.layers as sl | ||
from sinabs import from_nir, to_nir | ||
|
||
|
||
def test_iaf(): | ||
batch_size = 2 | ||
iaf = sl.IAFSqueeze(batch_size=batch_size) | ||
graph = to_nir(iaf, torch.randn(batch_size, 10)) | ||
converted = from_nir(graph, batch_size=batch_size) | ||
|
||
assert len(graph.nodes) == 1 + 2 | ||
assert isinstance(graph.nodes["model"], nir.IF) | ||
assert len(graph.edges) == 0 + 2 | ||
assert iaf.batch_size == converted.model.batch_size | ||
|
||
|
||
def test_conv2d(): | ||
batch_size = 2 | ||
conv2d = nn.Conv2d(1, 3, 2) | ||
graph = to_nir(conv2d, torch.randn(batch_size, 1, 32, 32)) | ||
converted = from_nir(graph, batch_size=batch_size) | ||
|
||
assert len(graph.nodes) == 1 + 2 | ||
assert isinstance(graph.nodes["model"], nir.Conv2d) | ||
assert len(graph.edges) == 0 + 2 | ||
assert conv2d.kernel_size == converted.model.kernel_size | ||
assert conv2d.stride == converted.model.stride | ||
assert conv2d.padding == converted.model.padding | ||
assert conv2d.dilation == converted.model.dilation | ||
|
||
|
||
def test_from_sequential_to_nir(): | ||
m = nn.Sequential( | ||
torch.nn.Linear(10, 2), | ||
sl.ExpLeak(tau_mem=10.0), | ||
sl.LIF(tau_mem=10.0), | ||
torch.nn.Linear(2, 1), | ||
) | ||
graph = to_nir(m, torch.randn(1, 10)) | ||
assert len(graph.nodes) == 4 + 2 | ||
assert isinstance(graph.nodes["0"], nir.Affine) | ||
assert isinstance(graph.nodes["1"], nir.LI) | ||
assert isinstance(graph.nodes["2"], nir.LIF) | ||
assert isinstance(graph.nodes["3"], nir.Affine) | ||
assert len(graph.edges) == 3 + 2 | ||
|
||
|
||
def test_from_linear_to_nir(): | ||
in_features = 2 | ||
out_features = 3 | ||
m = torch.nn.Linear(in_features, out_features, bias=False) | ||
m2 = torch.nn.Linear(in_features, out_features, bias=True) | ||
graph = to_nir(m, torch.randn(1, in_features)) | ||
assert len(graph.nodes) == 1 + 2 | ||
assert graph.nodes["model"].weight.shape == (out_features, in_features) | ||
assert graph.nodes["model"].bias.shape == m2.bias.shape | ||
|
||
|
||
def test_from_nir_to_sequential(): | ||
batch_size = 4 | ||
|
||
orig_model = nn.Sequential( | ||
torch.nn.Linear(10, 2), | ||
sl.ExpLeakSqueeze(tau_mem=10.0, batch_size=batch_size), | ||
sl.LIFSqueeze(tau_mem=10.0, batch_size=batch_size), | ||
torch.nn.Linear(2, 1), | ||
) | ||
nir_graph = to_nir(orig_model, torch.randn(batch_size, 10)) | ||
|
||
converted_model = from_nir(nir_graph, batch_size=batch_size) | ||
converted_modules = list(converted_model.children()) | ||
assert len(orig_model) + 2 == len( | ||
converted_modules | ||
) # Addition of input and output modules | ||
torch.testing.assert_allclose(orig_model[0].weight, converted_modules[1].weight) | ||
torch.testing.assert_allclose(orig_model[0].bias, converted_modules[1].bias) | ||
assert type(orig_model[1]) == type(converted_modules[2]) | ||
assert type(orig_model[2]) == type(converted_modules[3]) | ||
torch.testing.assert_allclose(orig_model[3].weight, converted_modules[4].weight) | ||
torch.testing.assert_allclose(orig_model[3].bias, converted_modules[4].bias) |