diff --git a/share/example/feed_forward.py b/share/example/feed_forward.py index 5b95424..a648c48 100644 --- a/share/example/feed_forward.py +++ b/share/example/feed_forward.py @@ -12,7 +12,9 @@ from zennit.attribution import Gradient, SmoothGrad, IntegratedGradients, Occlusion from zennit.composites import COMPOSITES +from zennit.core import Hook from zennit.image import imsave, CMAPS +from zennit.layer import Sum from zennit.torchvision import VGGCanonizer, ResNetCanonizer @@ -34,6 +36,17 @@ } +class SumSingle(Hook): + def __init__(self, dim=1): + super().__init__() + self.dim = dim + + def backward(self, module, grad_input, grad_output): + elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1]) + elems[self.dim] = grad_output[0] + return (torch.stack(elems, dim=-1),) + + class BatchNormalize: def __init__(self, mean, std, device=None): self.mean = torch.tensor(mean, device=device)[None, :, None, None] @@ -77,6 +90,7 @@ def find_classes(self, directory): @click.option('--cpu/--gpu', default=True) @click.option('--shuffle/--no-shuffle', default=False) @click.option('--with-bias/--no-bias', default=True) +@click.option('--with-residual/--no-residual', default=True) @click.option('--relevance-norm', type=click.Choice(['symmetric', 'absolute', 'unaligned']), default='symmetric') @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot') @click.option('--level', type=float, default=1.0) @@ -95,6 +109,7 @@ def main( cpu, shuffle, with_bias, + with_residual, cmap, level, relevance_norm, @@ -164,6 +179,9 @@ def attr_output_fn(output, target): # the highest and lowest pixel values for the ZBox rule composite_kwargs['low'] = norm_fn(torch.zeros(*shape, device=device)) composite_kwargs['high'] = norm_fn(torch.ones(*shape, device=device)) + if not with_residual and 'resnet' in model_name: + # skip the residual connection through the Sum added by the ResNetCanonizer + composite_kwargs['layer_map'] = [(Sum, SumSingle(1))] # provide the name 'bias' in zero_params if no bias should be used to compute the relevance if not with_bias and composite_name in [