diff --git a/setup.py b/setup.py
index ebbc0e9..6cdeb73 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@ def replace(mobj):
         'click',
         'numpy',
         'Pillow',
-        'torch>=1.7.0',
+        'torch>=2.0.0',
         'torchvision',
     ],
     setup_requires=[
diff --git a/src/zennit/core.py b/src/zennit/core.py
index 947f468..364bffd 100644
--- a/src/zennit/core.py
+++ b/src/zennit/core.py
@@ -19,6 +19,8 @@
 import functools
 import weakref
 from contextlib import contextmanager
+from itertools import compress, repeat, islice, chain
+from inspect import signature
 
 import torch
 
@@ -234,6 +236,43 @@ def modifier_wrapper(input, name):
     return zero_params_wrapper
 
 
+def uncompress(data, selector, compressed):
+    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress` and (some iterable similar
+    to) the original data and selector used for :py:obj:`~itertools.compress`, yields values from `compressed` or
+    `data` depending on `selector`. `True` values in `selector` skip `data` one ahead and yield a value from
+    `compressed`, while `False` values yield one value from `data`.
+
+    Parameters
+    ----------
+    data : iterable
+        The iterable (similar to the) original data. `False` values in the `selector` will be filled with values from
+        this iterator, while `True` values will cause this iterable to be skipped.
+    selector : iterable of bool
+        The original selector used to produce `compressed`. Chooses whether elements from `data` or from `compressed`
+        will be yielded.
+    compressed : iterable
+        The results of :py:obj:`itertools.compress`. Will be yielded for each `True` element in `selector`.
+
+    Yields
+    ------
+    object
+        An element of `data` if the associated element of `selector` is `False`, otherwise an element of `compressed`
+        while skipping `data` one ahead.
+
+    '''
+    its = iter(selector)
+    itc = iter(compressed)
+    itd = iter(data)
+    try:
+        while True:
+            if next(its):
+                next(itd)
+                yield next(itc)
+            else:
+                yield next(itd)
+    except StopIteration:
+        return
+
+
 class ParamMod:
     '''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
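The round trip between :py:obj:`itertools.compress` and the new `uncompress` helper, as a quick sketch (not part of the patch; the values are only illustrative, and it assumes the patched `zennit.core`):

    from itertools import compress
    from zennit.core import uncompress

    data = ('a', 1, 'b', 2)
    selector = (False, True, False, True)
    # compress keeps only the elements selected by True
    squeezed = list(compress(data, selector))               # [1, 2]
    # uncompress re-expands them, filling False positions from `data`
    restored = list(uncompress(data, selector, squeezed))   # ['a', 1, 'b', 2]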
@@ -360,6 +399,7 @@ class Identity(torch.autograd.Function):
     @staticmethod
     def forward(ctx, *inputs):
         '''Forward identity.'''
+        ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
         return inputs
 
     @staticmethod
@@ -375,62 +415,94 @@ def __init__(self):
         self.active = True
         self.tensor_handles = RemovableHandleList()
 
-    def pre_forward(self, module, input):
+    @staticmethod
+    def _inject_grad_fn(args):
+        tensor_mask = tuple(isinstance(elem, torch.Tensor) for elem in args)
+        tensors = tuple(compress(args, tensor_mask))
+        # tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]
+
+        # only if gradient required
+        if not any(tensor.requires_grad for tensor in tensors):
+            return None, args, tensor_mask
+
+        # add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
+        post_tensors = Identity.apply(*tensors)
+        grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
+        if grad_fn is None:
+            raise RuntimeError('Backward hook could not be registered!')
+
+        # work-around to support in-place operations
+        post_tensors = tuple(elem.clone() for elem in post_tensors)
+        post_args = tuple(uncompress(args, tensor_mask, post_tensors))
+        return grad_fn, post_args, tensor_mask
+
+    def pre_forward(self, module, args, kwargs):
         '''Apply an Identity to the input before the module to register a backward hook.'''
         hook_ref = weakref.ref(self)
 
+        grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                return hook.backward(
+                    module,
+                    list(uncompress(
+                        repeat(None),
+                        input_tensor_mask,
+                        grad_input,
+                    )),
+                    hook.stored_tensors['grad_output'],
+                )
             return None
 
-        if not isinstance(input, tuple):
-            input = (input,)
+        # register the input tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
-            # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+        return post_args, kwargs
 
-    def post_forward(self, module, input, output):
+    def post_forward(self, module, args, kwargs, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
         hook_ref = weakref.ref(self)
 
+        single = not isinstance(output, tuple)
+        if single:
+            output = (output,)
+
+        grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
+        if grad_fn is None:
+            return
+
         @functools.wraps(self.pre_backward)
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(
+                    module,
+                    grad_input,
+                    tuple(uncompress(
+                        repeat(None),
+                        output_tensor_mask,
+                        grad_output
+                    ))
+                )
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        # register the output tensor gradient hook
+        self.tensor_handles.append(grad_fn.register_hook(wrapper))
 
-        # only if gradient required
-        if output[0].grad_fn is not None:
-            # register the output tensor gradient hook
-            self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
-            )
-        return output[0] if len(output) == 1 else output
+        if single:
+            return post_output[0]
+        return post_output
 
     def pre_backward(self, module, grad_input, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Hook applied during forward-pass'''
 
     def backward(self, module, grad_input, grad_output):
@@ -449,11 +521,14 @@ def remove(self):
 
     def register(self, module):
         '''Register this instance by registering all hooks to the supplied module.'''
+        # assume with_kwargs if forward does not have exactly 3 parameters and its 3rd parameter is not called 'output'
+        forward_params = signature(self.forward).parameters
+        with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
         return RemovableHandleList([
             RemovableHandle(self),
-            module.register_forward_pre_hook(self.pre_forward),
-            module.register_forward_hook(self.post_forward),
-            module.register_forward_hook(self.forward),
+            module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
+            module.register_forward_hook(self.post_forward, with_kwargs=True),
+            module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
         ])
 
 
@@ -517,31 +592,61 @@ def __init__(
         self.gradient_mapper = gradient_mapper
         self.reducer = reducer
 
-    def forward(self, module, input, output):
+    def forward(self, module, args, kwargs, output):
         '''Forward hook to save module in-/outputs.'''
-        self.stored_tensors['input'] = input
+        self.stored_tensors['input'] = args
+        self.stored_tensors['kwargs'] = kwargs
 
     def backward(self, module, grad_input, grad_output):
         '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        input_mask = [elem is not None for elem in self.stored_tensors['input']]
+        output_mask = [elem is not None for elem in grad_output]
+        cgrad_output = tuple(compress(grad_output, output_mask))
+
+        original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
+        kwargs = self.stored_tensors['kwargs']
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
+            args = tuple(uncompress(original_inputs, input_mask, mod_args))
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
+                output = modified.forward(*args, **kwargs)
+                if not isinstance(output, tuple):
+                    output = (output,)
+                output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
+            inputs.append(compress(args, input_mask))
             outputs.append(output)
-        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
-        gradients = torch.autograd.grad(
-            outputs,
-            inputs,
+
+        inputs = list(zip(*inputs))
+        outputs = list(zip(*outputs))
+        input_struct = [len(elem) for elem in inputs]
+        output_struct = [len(elem) for elem in outputs]
+
+        grad_outputs = tuple(
+            self.gradient_mapper(gradout, outs)
+            for gradout, outs in zip(cgrad_output, outputs)
+        )
+        inputs_flat = tuple(chain.from_iterable(inputs))
+        outputs_flat = tuple(chain.from_iterable(outputs))
+        if not all(isinstance(elem, torch.Tensor) for elem in grad_outputs):
+            # with a single output modifier, grad_outputs may contain plain tensors; here, flatten the nested tuples
+            grad_outputs = tuple(chain.from_iterable(grad_outputs))
+
+        gradients_flat = torch.autograd.grad(
+            outputs_flat,
+            inputs_flat,
             grad_outputs=grad_outputs,
-            create_graph=grad_output[0].requires_grad
+            create_graph=any(tensor.requires_grad for tensor in cgrad_output)
         )
-        relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+
+        # input_it = iter(inputs)
+        # inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
+        gradient_it = iter(gradients_flat)
+        gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
+
+        relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
+        return tuple(uncompress(repeat(None), input_mask, relevances))
 
     def copy(self):
        '''Return a copy of this hook.
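To see how `backward` regroups the flattened gradients per input (one tuple per input, one entry per modifier), a small sketch with placeholder values (not part of the patch):

    from itertools import islice

    inputs = [('x0_mod0', 'x0_mod1'), ('x1_mod0', 'x1_mod1')]   # one tuple per input, one entry per modifier
    input_struct = [len(elem) for elem in inputs]               # [2, 2]
    gradients_flat = ('g0_mod0', 'g0_mod1', 'g1_mod0', 'g1_mod1')
    gradient_it = iter(gradients_flat)
    gradients = [tuple(islice(gradient_it, size)) for size in input_struct]
    # gradients == [('g0_mod0', 'g0_mod1'), ('g1_mod0', 'g1_mod1')], aligned with `inputs` for the reducer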
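For context on the new registration logic: a hook whose `forward` keeps the old `(module, input, output)` signature is still registered without keyword arguments, while a `forward` written against the new `(module, args, kwargs, output)` signature is detected by `register` and receives keyword arguments. A minimal sketch, assuming the patched `zennit.core`; the `PrintShapeHook` name and body are purely illustrative:

    import torch
    from zennit.core import Hook

    class PrintShapeHook(Hook):
        # 4 parameters, 3rd not named 'output': register() uses with_kwargs=True
        def forward(self, module, args, kwargs, output):
            print([tuple(elem.shape) for elem in args if isinstance(elem, torch.Tensor)])

    module = torch.nn.Linear(4, 2)
    hook = PrintShapeHook()
    handles = hook.register(module)
    module(torch.randn(3, 4, requires_grad=True))
    handles.remove()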