diff --git a/src/zennit/rules.py b/src/zennit/rules.py index c918916..e2f9ff1 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -471,7 +471,7 @@ def copy(self): def max_fn(self, input, kernel_size, stride, padding, dilation): raise NotImplementedError("Implement in subclass") - def sum_fn(self, input, kernel, stride, padding, dilation): + def sum_fn(self, input, kernel_size, stride, padding, dilation): raise NotImplementedError("Implement in subclass") def forward(self, module, input, output):