Core: Multiple Inputs and Keyword Arguments
- use the with_kwargs additions to forward hooks in torch 2.0.0 to pass
  keyword arguments to hooks (see the sketch below)
- handle multiple inputs and outputs in core.Hook and core.BasicHook, by
  passing all required grad_outputs and inputs to the backward
  implementation

TODO:

- finish draft and test implementation
- add tests
- add documentation

- This conflicts with #168, but promises a better implementation by
  handling inputs and outputs as belonging to a single function, rather
  than individually as proposed in #168
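
A minimal sketch (not part of this commit) of the torch >= 2.0 hook registration this change relies on; the module and hook names below are placeholders:

import torch

def pre_hook(module, args, kwargs):
    # with with_kwargs=True, a forward pre-hook receives keyword arguments and may
    # return an (args, kwargs) tuple to replace both
    return args, kwargs

def forward_hook(module, args, kwargs, output):
    # with with_kwargs=True, a forward hook also receives the keyword arguments
    return output

module = torch.nn.Linear(4, 2)
handles = [
    module.register_forward_pre_hook(pre_hook, with_kwargs=True),
    module.register_forward_hook(forward_hook, with_kwargs=True),
]
module(torch.randn(3, 4))
for handle in handles:
    handle.remove()
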
chr5tphr committed Apr 8, 2024
1 parent 065c821 commit 839d937
Showing 2 changed files with 152 additions and 49 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -54,7 +54,7 @@ def replace(mobj):
'click',
'numpy',
'Pillow',
'torch>=1.7.0',
'torch>=2.0.0',
'torchvision',
],
setup_requires=[
199 changes: 151 additions & 48 deletions src/zennit/core.py
@@ -19,6 +19,8 @@
import functools
import weakref
from contextlib import contextmanager
from itertools import compress, repeat, islice
from inspect import signature

import torch

@@ -234,6 +236,43 @@ def modifier_wrapper(input, name):
return zero_params_wrapper


def uncompress(data, selector, compressed):
    '''Generator which, given a compressed iterable produced by :py:obj:`itertools.compress`, (an iterable similar
    to) the original data, and the selector used for :py:obj:`~itertools.compress`, yields values from `compressed`
    or `data` depending on `selector`. A `True` value in `selector` advances `data` by one element and yields a
    value from `compressed`, while a `False` value yields one value from `data`.

    Parameters
    ----------
    data : iterable
        The iterable (similar to the) original data. `False` values in `selector` will be filled with values from
        this iterable, while `True` values will cause this iterable to be skipped one ahead.
    selector : iterable of bool
        The original selector used to produce `compressed`. Chooses whether elements from `data` or from
        `compressed` will be yielded.
    compressed : iterable
        The result of :py:obj:`itertools.compress`. One element will be yielded for each `True` element in
        `selector`.

    Yields
    ------
    object
        An element of `data` if the associated element of `selector` is `False`, otherwise an element of
        `compressed`, while skipping `data` one ahead.
    '''
    its = iter(selector)
    itc = iter(compressed)
    itd = iter(data)
    try:
        while True:
            if next(its):
                next(itd)
                yield next(itc)
            else:
                yield next(itd)
    except StopIteration:
        return
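
# Illustration (not part of this commit): uncompress inverts itertools.compress by
# re-inserting the skipped elements from the original data, e.g.
#
#     >>> list(compress(['a', 'b', 'c', 'd'], [True, False, True, False]))
#     ['a', 'c']
#     >>> list(uncompress(['a', 'b', 'c', 'd'], [True, False, True, False], ['A', 'C']))
#     ['A', 'b', 'C', 'd']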


class ParamMod:
'''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
@@ -360,6 +399,7 @@ class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs):
'''Forward identity.'''
# mark tensors which do not require a gradient as non-differentiable outputs
ctx.mark_non_differentiable(*[elem for elem in inputs if not elem.requires_grad])
return inputs

@staticmethod
@@ -375,62 +415,94 @@ def __init__(self):
self.active = True
self.tensor_handles = RemovableHandleList()

def pre_forward(self, module, input):
@staticmethod
def _inject_grad_fn(args):
tensor_mask = [isinstance(elem, torch.Tensor) for elem in args]
tensors = tuple(compress(args, tensor_mask))
# tensors = [(n, elem) for elem in enumerate(args) if isinstance(elem, torch.Tensor)]

# only if gradient required
if not any(tensor.requires_grad for tensor in tensors):
return None, args, tensor_mask

# add identity to ensure .grad_fn exists and all tensors share the same .grad_fn
post_tensors = Identity.apply(*tensors)
grad_fn = next((tensor.grad_fn for tensor in post_tensors if tensor.grad_fn is not None), None)
if grad_fn is None:
raise RuntimeError('Backward hook could not be registered!')

# work-around to support in-place operations
# post_tensors = tuple(elem.clone() for elem in post_tensors)
post_args = tuple(uncompress(args, tensor_mask, post_tensors))
return grad_fn, post_args, tensor_mask
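
    # Illustration (not part of this commit): Identity.apply gives the passed tensors a
    # shared .grad_fn node, so a single gradient hook can observe all of them at once:
    #
    #     >>> a = torch.ones(2, requires_grad=True)
    #     >>> b = torch.ones(2, requires_grad=True)
    #     >>> post_a, post_b = Identity.apply(a, b)
    #     >>> post_a.grad_fn is post_b.grad_fn
    #     True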

def pre_forward(self, module, args, kwargs):
'''Apply an Identity to the input before the module to register a backward hook.'''
hook_ref = weakref.ref(self)

grad_fn, post_args, input_tensor_mask = self._inject_grad_fn(args)
if grad_fn is None:
return

@functools.wraps(self.backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
return hook.backward(
module,
list(uncompress(
repeat(None),
input_tensor_mask,
grad_input,
)),
hook.stored_tensors['grad_output'],
)
return None

if not isinstance(input, tuple):
input = (input,)
# register the input tensor gradient hook
self.tensor_handles.append(grad_fn.register_hook(wrapper))

# only if gradient required
if input[0].requires_grad:
# add identity to ensure .grad_fn exists
post_input = Identity.apply(*input)
# register the input tensor gradient hook
self.tensor_handles.append(
post_input[0].grad_fn.register_hook(wrapper)
)
# work around to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
else:
# no gradient required
post_input = input
return post_input[0] if len(post_input) == 1 else post_input
return post_args, kwargs

def post_forward(self, module, input, output):
def post_forward(self, module, args, kwargs, output):
'''Register a backward-hook to the resulting tensor right after the forward.'''
hook_ref = weakref.ref(self)

single = not isinstance(output, tuple)
if single:
output = (output,)

grad_fn, post_output, output_tensor_mask = self._inject_grad_fn(output)
if grad_fn is None:
return

@functools.wraps(self.pre_backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.pre_backward(module, grad_input, grad_output)
return hook.pre_backward(
module,
grad_input,
list(uncompress(
repeat(None),
output_tensor_mask,
grad_output
))
)
return None

if not isinstance(output, tuple):
output = (output,)
# register the output tensor gradient hook
self.tensor_handles.append(grad_fn.register_hook(wrapper))

# only if gradient required
if output[0].grad_fn is not None:
# register the output tensor gradient hook
self.tensor_handles.append(
output[0].grad_fn.register_hook(wrapper)
)
return output[0] if len(output) == 1 else output
if single:
return post_output[0]
return post_output

def pre_backward(self, module, grad_input, grad_output):
'''Store the grad_output for the backward hook'''
self.stored_tensors['grad_output'] = grad_output

def forward(self, module, input, output):
def forward(self, module, args, kwargs, output):
'''Hook applied during forward-pass'''

def backward(self, module, grad_input, grad_output):
Expand All @@ -449,11 +521,14 @@ def remove(self):

def register(self, module):
'''Register this instance by registering all hooks to the supplied module.'''
# assume with_kwargs if forward does not have 3 parameters and its 3rd parameter is not called 'output'
forward_params = signature(self.forward).parameters
with_kwargs = len(forward_params) != 3 and list(forward_params)[2] != 'output'
return RemovableHandleList([
RemovableHandle(self),
module.register_forward_pre_hook(self.pre_forward),
module.register_forward_hook(self.post_forward),
module.register_forward_hook(self.forward),
module.register_forward_pre_hook(self.pre_forward, with_kwargs=True),
module.register_forward_hook(self.post_forward, with_kwargs=True),
module.register_forward_hook(self.forward, with_kwargs=with_kwargs),
])
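
    # Illustration (not part of this commit): the signature check above selects the hook
    # convention from how a subclass defines forward, e.g.
    #
    #     class OldStyleHook(Hook):
    #         def forward(self, module, input, output):
    #             '''3 parameters, 3rd called output -> registered without kwargs'''
    #
    #     class KwargsHook(Hook):
    #         def forward(self, module, args, kwargs, output):
    #             '''4 parameters -> registered with with_kwargs=True'''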


@@ -517,31 +592,59 @@ def __init__(
self.gradient_mapper = gradient_mapper
self.reducer = reducer

def forward(self, module, input, output):
def forward(self, module, args, kwargs, output):
'''Forward hook to save module in-/outputs.'''
self.stored_tensors['input'] = input
self.stored_tensors['input'] = args
self.stored_tensors['kwargs'] = kwargs

def backward(self, module, grad_input, grad_output):
'''Backward hook to compute LRP based on the class attributes.'''
original_input = self.stored_tensors['input'][0].clone()
input_mask = [elem is not None for elem in self.stored_tensors['input']]
output_mask = [elem is not None for elem in grad_output]
cgrad_output = tuple(compress(grad_output, output_mask))

original_inputs = [tensor.clone() for tensor in self.stored_tensors['input']]
kwargs = self.stored_tensors['kwargs']
inputs = []
outputs = []
for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
input = in_mod(original_input).requires_grad_()
mod_args = (in_mod(tensor).requires_grad_() for tensor in compress(original_inputs, input_mask))
args = tuple(uncompress(original_inputs, input_mask, mod_args))
with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
output = modified.forward(input)
output = out_mod(output)
inputs.append(input)
output = modified.forward(*args, **kwargs)
if not isinstance(output, tuple):
output = (output,)
output = tuple(out_mod(tensor) for tensor in compress(output, output_mask))
inputs.append(compress(args, input_mask))
outputs.append(output)
grad_outputs = self.gradient_mapper(grad_output[0], outputs)
gradients = torch.autograd.grad(
outputs,
inputs,
grad_outputs=grad_outputs,
create_graph=grad_output[0].requires_grad

inputs = list(zip(*inputs))
outputs = list(zip(*outputs))
input_struct = [len(elem) for elem in inputs]
output_struct = [len(elem) for elem in outputs]

grad_outputs = tuple(
self.gradient_mapper(gradout, outs)
for gradout, outs in zip(cgrad_output, outputs)
)
relevance = self.reducer(inputs, gradients)
return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
inputs_flat = sum(inputs, ())
outputs_flat = sum(outputs, ())
grad_outputs_flat = sum(((grad_out,) * size for grad_out, size in zip(grad_outputs, output_struct)), ())

gradients_flat = torch.autograd.grad(
outputs_flat,
inputs_flat,
grad_outputs=grad_outputs_flat,
create_graph=any(tensor.requires_grad for tensor in cgrad_output)
)

# input_it = iter(inputs)
# inputs_re = [tuple(islice(input_it, size)) for size in input_struct]
gradient_it = iter(gradients_flat)
gradients = [tuple(islice(gradient_it, size)) for size in input_struct]

relevances = (self.reducer(inp, grad) for inp, grad in zip(inputs, gradients))
return tuple(uncompress(repeat(None), input_mask, relevances))
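
    # Illustration (not part of this commit): regrouping the flat gradient tuple into
    # per-entry groups of known sizes with islice, as done above:
    #
    #     >>> flat = ('g0a', 'g0b', 'g1a', 'g1b')
    #     >>> it = iter(flat)
    #     >>> [tuple(islice(it, size)) for size in [2, 2]]
    #     [('g0a', 'g0b'), ('g1a', 'g1b')]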

def copy(self):
'''Return a copy of this hook.
