diff --git a/src/zennit/layer.py b/src/zennit/layer.py index 6b52539..301c402 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -141,9 +141,6 @@ class MinPool2d(torch.nn.MaxPool2d): >>> x = torch.randn(1, 1, 4, 4) >>> pool(x) ''' - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): - super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) - def forward(self, input): '''Computes the min pool of `input`. @@ -184,9 +181,6 @@ class MinPool1d(torch.nn.MaxPool1d): >>> x = torch.randn(1, 1, 4) >>> pool(x) ''' - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): - super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) - def forward(self, input): '''Computes the min pool of `input`. diff --git a/src/zennit/rules.py b/src/zennit/rules.py index c918916..065ca50 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -469,12 +469,15 @@ def copy(self): return self.__class__(beta=self.beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' raise NotImplementedError("Implement in subclass") - def sum_fn(self, input, kernel, stride, padding, dilation): + def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' raise NotImplementedError("Implement in subclass") def forward(self, module, input, output): + '''Stores the input for later use in the backward pass.''' self.stored_tensors['input'] = input def backward(self, module, grad_input, grad_output): @@ -508,9 +511,11 @@ def __init__(self, beta=1.0): super().__init__(-beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, @@ -526,13 +531,12 @@ class MaxTakesMost1d(TakesMostBase): __init__(beta=1.0): Initializes the MaxTakesMost1d class. ''' - def __init__(self, beta=1.0): - super().__init__(beta) - def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, @@ -552,9 +556,11 @@ def __init__(self, beta=1.0): super().__init__(-beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) @@ -572,13 +578,12 @@ class MaxTakesMost2d(TakesMostBase): __init__(beta=1.0): Initializes the MaxTakesMost2d class. ''' - def __init__(self, beta=1.0): - super().__init__(beta) - def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size)