Skip to content

Commit

Permalink
tox compliance
Browse files Browse the repository at this point in the history
- various non-functional changes
  • Loading branch information
jacobkauffmann committed Sep 8, 2023
1 parent 6a55fde commit 7fb44cc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
6 changes: 0 additions & 6 deletions src/zennit/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ class MinPool2d(torch.nn.MaxPool2d):
>>> x = torch.randn(1, 1, 4, 4)
>>> pool(x)
'''
def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)

def forward(self, input):
'''Computes the min pool of `input`.
Expand Down Expand Up @@ -184,9 +181,6 @@ class MinPool1d(torch.nn.MaxPool1d):
>>> x = torch.randn(1, 1, 4)
>>> pool(x)
'''
def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)

def forward(self, input):
'''Computes the min pool of `input`.
Expand Down
19 changes: 12 additions & 7 deletions src/zennit/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,15 @@ def copy(self):
return self.__class__(beta=self.beta)

def max_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the maximum value in a local window for each entry in the input tensor.'''
raise NotImplementedError("Implement in subclass")

def sum_fn(self, input, kernel, stride, padding, dilation):
def sum_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the sum of elements in a local window for each entry in the input tensor.'''
raise NotImplementedError("Implement in subclass")

def forward(self, module, input, output):
'''Stores the input for later use in the backward pass.'''
self.stored_tensors['input'] = input

def backward(self, module, grad_input, grad_output):
Expand Down Expand Up @@ -508,9 +511,11 @@ def __init__(self, beta=1.0):
super().__init__(-beta)

def max_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the maximum value in a local window for each entry in the input tensor.'''
return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

def sum_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the sum of elements in a local window for each entry in the input tensor.'''
in_channels = input.shape[1]
kernel = torch.ones((in_channels, 1, kernel_size), device=input.device)
return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation,
Expand All @@ -526,13 +531,12 @@ class MaxTakesMost1d(TakesMostBase):
__init__(beta=1.0):
Initializes the MaxTakesMost1d class.
'''
def __init__(self, beta=1.0):
super().__init__(beta)

def max_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the maximum value in a local window for each entry in the input tensor.'''
return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

def sum_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the sum of elements in a local window for each entry in the input tensor.'''
in_channels = input.shape[1]
kernel = torch.ones((in_channels, 1, kernel_size), device=input.device)
return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation,
Expand All @@ -552,9 +556,11 @@ def __init__(self, beta=1.0):
super().__init__(-beta)

def max_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the maximum value in a local window for each entry in the input tensor.'''
return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

def sum_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the sum of elements in a local window for each entry in the input tensor.'''
in_channels = input.shape[1]
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
Expand All @@ -572,13 +578,12 @@ class MaxTakesMost2d(TakesMostBase):
__init__(beta=1.0):
Initializes the MaxTakesMost2d class.
'''
def __init__(self, beta=1.0):
super().__init__(beta)

def max_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the maximum value in a local window for each entry in the input tensor.'''
return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)

def sum_fn(self, input, kernel_size, stride, padding, dilation):
'''Computes the sum of elements in a local window for each entry in the input tensor.'''
in_channels = input.shape[1]
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
Expand Down

0 comments on commit 7fb44cc

Please sign in to comment.