Commit 45448fc

Merge pull request #104 from synsense/dev/nir

Dev/nir

ssinhaleite authored Dec 7, 2023
2 parents a2e28bc + 336fa9d commit 45448fc
Showing 6 changed files with 324 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pipeline.yml
@@ -58,7 +58,7 @@ jobs:
          coverage run -m pytest tests
          coverage xml
      - name: Upload coverage to Codecov
-       uses: codecov/codecov-action@v2
+       uses: codecov/codecov-action@v3

  documentation:
    needs: multitest
12 changes: 12 additions & 0 deletions codecov.yml
@@ -0,0 +1,12 @@
coverage:
  status:
    project:
      default:
        # basic
        target: 60%
        threshold: 0%
    patch:
      default:
        # basic
        target: 60%
        threshold: 0%
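In words: both the overall project status and the per-patch status require at least 60% coverage, and the 0% threshold leaves no leeway below that target before the Codecov check fails (standard Codecov semantics for these keys).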
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,3 +1,5 @@
pbr
numpy
torch>=1.8
nir
nirtorch
1 change: 1 addition & 0 deletions sinabs/__init__.py
@@ -5,5 +5,6 @@
from . import conversion, utils
from .from_torch import from_model
from .network import Network
from .nir import *
from .synopcounter import SNNAnalyzer, SynOpCounter
from .utils import reset_states, set_batch_size, zero_grad
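Through the wildcard import of sinabs.nir, the new converters become available from the top-level package, which is how the test file at the end of this diff imports them:

from sinabs import from_nir, to_nir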
224 changes: 224 additions & 0 deletions sinabs/nir.py
@@ -0,0 +1,224 @@
from functools import partial
from typing import Optional, Tuple, Union

import nir
import nirtorch
import numpy as np
import torch
from torch import nn

import sinabs.layers as sl


def _as_pair(x) -> Tuple[int, int]:
    try:
        if len(x) == 1:
            return (x[0], x[0])
        elif len(x) >= 2:
            return tuple(x)
        else:
            raise ValueError()
    except TypeError:
        return x, x
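A brief illustration of how this helper normalizes kernel sizes and strides (hypothetical calls, not part of the diff):

# Scalars have no len(), so they fall through to the TypeError branch and are duplicated;
# length-1 sequences are duplicated, longer ones are passed through as a tuple.
_as_pair(3)       # -> (3, 3)
_as_pair([2])     # -> (2, 2)
_as_pair((2, 4))  # -> (2, 4)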


def _import_sinabs_module(
    node: nir.NIRNode, batch_size: int, num_timesteps: int
) -> torch.nn.Module:
    if isinstance(node, nir.Affine):
        linear = nn.Linear(
            in_features=node.weight.shape[1],
            out_features=node.weight.shape[0],
            bias=True,
        )
        linear.weight.data = torch.tensor(node.weight).float()
        linear.bias.data = torch.tensor(node.bias).float()
        return linear

    elif isinstance(node, nir.Conv1d):
        conv = nn.Conv1d(
            in_channels=node.weight.shape[1],
            out_channels=node.weight.shape[0],
            kernel_size=node.weight.shape[2:],
            stride=node.stride,
            padding=node.padding,
            dilation=node.dilation,
            groups=node.groups,
            bias=True,
        )
        conv.weight.data = torch.tensor(node.weight).float()
        conv.bias.data = torch.tensor(node.bias).float()
        return conv

    elif isinstance(node, nir.Conv2d):
        conv = nn.Conv2d(
            in_channels=node.weight.shape[1],
            out_channels=node.weight.shape[0],
            kernel_size=node.weight.shape[2:],
            stride=node.stride,
            padding=node.padding,
            dilation=node.dilation,
            groups=node.groups,
            bias=True,
        )
        conv.weight.data = torch.tensor(node.weight).float()
        conv.bias.data = torch.tensor(node.bias).float()
        return conv

    elif isinstance(node, nir.LI):
        if node.v_leak.shape == torch.Size([]):
            node.v_leak = node.v_leak.unsqueeze(0)
        if node.r.shape == torch.Size([]):
            node.r = node.r.unsqueeze(0)
        if any(node.v_leak != 0):
            raise ValueError("`v_leak` must be 0")
        if any(node.r != 1):
            raise ValueError("`r` must be 1")
        # TODO check for norm_input
        return sl.ExpLeakSqueeze(
            tau_mem=node.tau,
            min_v_mem=None,
            num_timesteps=num_timesteps,
            batch_size=batch_size,
            norm_input=False,
        )

    elif isinstance(node, nir.IF):
        return sl.IAFSqueeze(
            min_v_mem=-node.v_threshold,
            num_timesteps=num_timesteps,
            batch_size=batch_size,
            spike_threshold=node.v_threshold,
        )

    elif isinstance(node, nir.LIF):
        if node.v_leak.shape == torch.Size([]):
            node.v_leak = node.v_leak.unsqueeze(0)
        # Compare element-wise, then reduce (same check as in the LI branch above)
        if any(node.v_leak != 0):
            raise ValueError("`v_leak` must be 0")
        # TODO check for norm_input
        return sl.LIFSqueeze(
            tau_mem=node.tau,
            min_v_mem=None,
            num_timesteps=num_timesteps,
            batch_size=batch_size,
            spike_threshold=node.v_threshold,
            tau_syn=None,
            norm_input=False,
        )
    elif isinstance(node, nir.SumPool2d):
        return sl.SumPool2d(
            kernel_size=tuple(node.kernel_size), stride=tuple(node.stride)
        )
    elif isinstance(node, nir.Flatten):
        start_dim = node.start_dim + 1 if node.start_dim >= 0 else node.start_dim
        end_dim = node.end_dim + 1 if node.end_dim >= 0 else node.end_dim
        return nn.Flatten(start_dim=start_dim, end_dim=end_dim)
    elif isinstance(node, nir.Input):
        return nn.Identity()
    elif isinstance(node, nir.Output):
        return nn.Identity()


def from_nir(
    node: nir.NIRNode,
    batch_size: Optional[int] = None,
    num_timesteps: Optional[int] = None,
) -> torch.nn.Module:
    return nirtorch.load(
        node,
        partial(
            _import_sinabs_module, batch_size=batch_size, num_timesteps=num_timesteps
        ),
    )
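A short sketch of how the loader might be used, mirroring the round trip exercised in the tests at the end of this diff (layer sizes are illustrative only):

import torch
import sinabs.layers as sl
from sinabs import from_nir, to_nir

batch_size = 4
model = torch.nn.Sequential(
    torch.nn.Linear(10, 2),
    sl.IAFSqueeze(batch_size=batch_size),
)
# Export to a NIR graph, then rebuild a sinabs module graph from it.
# nirtorch adds Input/Output identity nodes around the converted layers,
# which is why the tests below count two extra nodes and edges.
graph = to_nir(model, sample_data=torch.randn(batch_size, 10))
restored = from_nir(graph, batch_size=batch_size)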


def _extend_to_shape(x: Union[torch.Tensor, float], shape: Tuple) -> torch.Tensor:
    if x.shape == shape:
        return x
    elif x.shape == (1,) or x.dim() == 0:
        return torch.ones(*shape) * x
    else:
        raise ValueError(f"Not sure how to extend {x} to shape {shape}")
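For reference, the broadcast helper behaves as follows (hypothetical calls):

_extend_to_shape(torch.tensor(1.0), (2, 3))  # scalar -> 2x3 tensor of ones
_extend_to_shape(torch.ones(2, 3), (2, 3))   # matching shape -> returned unchanged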


def _extract_sinabs_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:
    if type(module) in [sl.IAF, sl.IAFSqueeze]:
        layer_shape = module.v_mem.shape[1:]
        nir_node = nir.IF(
            r=torch.ones(*layer_shape),  # Discard batch dim
            v_threshold=_extend_to_shape(module.spike_threshold.detach(), layer_shape),
        )
        return nir_node
    elif type(module) in [sl.LIF, sl.LIFSqueeze]:
        layer_shape = module.v_mem.shape[0]
        return nir.LIF(
            tau=module.tau_mem.detach(),
            v_threshold=module.spike_threshold.detach(),
            v_leak=torch.zeros_like(module.tau_mem.detach()),
            r=torch.ones_like(module.tau_mem.detach()),
        )
    elif type(module) in [sl.ExpLeak, sl.ExpLeakSqueeze]:
        return nir.LI(
            tau=module.tau_mem.detach(),
            v_leak=torch.zeros_like(module.tau_mem.detach()),
            r=torch.ones_like(module.tau_mem.detach()),
        )
    elif isinstance(module, torch.nn.Linear):
        if module.bias is None:  # Add zero bias if none is present
            return nir.Affine(
                module.weight.detach(), torch.zeros(*module.weight.shape[:-1])
            )
        else:
            return nir.Affine(module.weight.detach(), module.bias.detach())
    elif isinstance(module, torch.nn.Conv1d):
        return nir.Conv1d(
            weight=module.weight.detach(),
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias.detach()
            if isinstance(module.bias, torch.Tensor)
            else torch.zeros((module.weight.shape[0])),
        )
    elif isinstance(module, torch.nn.Conv2d):
        return nir.Conv2d(
            input_shape=None,
            weight=module.weight.detach(),
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias.detach()
            if isinstance(module.bias, torch.Tensor)
            else torch.zeros((module.weight.shape[0])),
        )
    elif isinstance(module, sl.SumPool2d):
        return nir.SumPool2d(
            kernel_size=_as_pair(module.kernel_size),  # (Height, Width)
            stride=_as_pair(
                module.kernel_size if module.stride is None else module.stride
            ),  # (Height, Width)
            padding=(0, 0),  # (Height, Width)
        )
    elif isinstance(module, nn.Flatten):
        # Getting rid of the batch dimension for NIR
        start_dim = module.start_dim - 1 if module.start_dim > 0 else module.start_dim
        end_dim = module.end_dim - 1 if module.end_dim > 0 else module.end_dim
        return nir.Flatten(
            input_type=None,
            start_dim=start_dim,
            end_dim=end_dim,
        )
    raise NotImplementedError(f"Module {type(module)} not supported")


def to_nir(
    module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "model"
) -> nir.NIRNode:
    return nirtorch.extract_nir_graph(
        module,
        _extract_sinabs_module,
        sample_data,
        model_name=model_name,
        ignore_dims=[0],
    )
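A sketch of exporting a sinabs model to NIR so that another NIR-compatible backend can pick it up; the nir.write call and its on-disk format are assumptions about the nir package, not something defined in this diff:

import nir
import torch
import sinabs.layers as sl
from sinabs import to_nir

model = torch.nn.Sequential(
    torch.nn.Linear(16, 4),
    sl.LIF(tau_mem=10.0),
)
# ignore_dims=[0] strips the batch dimension from the traced shapes
graph = to_nir(model, sample_data=torch.randn(1, 16), model_name="demo")
nir.write("demo.nir", graph)  # assumed nir serialization helper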
84 changes: 84 additions & 0 deletions tests/test_nir.py
@@ -0,0 +1,84 @@
import nir
import torch
import torch.nn as nn

import sinabs.layers as sl
from sinabs import from_nir, to_nir


def test_iaf():
    batch_size = 2
    iaf = sl.IAFSqueeze(batch_size=batch_size)
    graph = to_nir(iaf, torch.randn(batch_size, 10))
    converted = from_nir(graph, batch_size=batch_size)

    assert len(graph.nodes) == 1 + 2
    assert isinstance(graph.nodes["model"], nir.IF)
    assert len(graph.edges) == 0 + 2
    assert iaf.batch_size == converted.model.batch_size


def test_conv2d():
    batch_size = 2
    conv2d = nn.Conv2d(1, 3, 2)
    graph = to_nir(conv2d, torch.randn(batch_size, 1, 32, 32))
    converted = from_nir(graph, batch_size=batch_size)

    assert len(graph.nodes) == 1 + 2
    assert isinstance(graph.nodes["model"], nir.Conv2d)
    assert len(graph.edges) == 0 + 2
    assert conv2d.kernel_size == converted.model.kernel_size
    assert conv2d.stride == converted.model.stride
    assert conv2d.padding == converted.model.padding
    assert conv2d.dilation == converted.model.dilation


def test_from_sequential_to_nir():
    m = nn.Sequential(
        torch.nn.Linear(10, 2),
        sl.ExpLeak(tau_mem=10.0),
        sl.LIF(tau_mem=10.0),
        torch.nn.Linear(2, 1),
    )
    graph = to_nir(m, torch.randn(1, 10))
    assert len(graph.nodes) == 4 + 2
    assert isinstance(graph.nodes["0"], nir.Affine)
    assert isinstance(graph.nodes["1"], nir.LI)
    assert isinstance(graph.nodes["2"], nir.LIF)
    assert isinstance(graph.nodes["3"], nir.Affine)
    assert len(graph.edges) == 3 + 2


def test_from_linear_to_nir():
    in_features = 2
    out_features = 3
    m = torch.nn.Linear(in_features, out_features, bias=False)
    m2 = torch.nn.Linear(in_features, out_features, bias=True)
    graph = to_nir(m, torch.randn(1, in_features))
    assert len(graph.nodes) == 1 + 2
    assert graph.nodes["model"].weight.shape == (out_features, in_features)
    assert graph.nodes["model"].bias.shape == m2.bias.shape


def test_from_nir_to_sequential():
    batch_size = 4

    orig_model = nn.Sequential(
        torch.nn.Linear(10, 2),
        sl.ExpLeakSqueeze(tau_mem=10.0, batch_size=batch_size),
        sl.LIFSqueeze(tau_mem=10.0, batch_size=batch_size),
        torch.nn.Linear(2, 1),
    )
    nir_graph = to_nir(orig_model, torch.randn(batch_size, 10))

    converted_model = from_nir(nir_graph, batch_size=batch_size)
    converted_modules = list(converted_model.children())
    assert len(orig_model) + 2 == len(
        converted_modules
    )  # Addition of input and output modules
    torch.testing.assert_allclose(orig_model[0].weight, converted_modules[1].weight)
    torch.testing.assert_allclose(orig_model[0].bias, converted_modules[1].bias)
    assert type(orig_model[1]) == type(converted_modules[2])
    assert type(orig_model[2]) == type(converted_modules[3])
    torch.testing.assert_allclose(orig_model[3].weight, converted_modules[4].weight)
    torch.testing.assert_allclose(orig_model[3].bias, converted_modules[4].bias)
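These tests run with the rest of the suite (e.g. pytest tests/test_nir.py) and feed into the coverage report uploaded by the CI workflow above.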
