From fc915a055e8991d1004c450ee0ccce3a08e43319 Mon Sep 17 00:00:00 2001 From: rachtibat <48683081+rachtibat@users.noreply.github.com> Date: Wed, 4 Oct 2023 12:46:16 +0200 Subject: [PATCH] Add torchvision.ops.FrozenBatchNorm2d to types --- src/zennit/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zennit/types.py b/src/zennit/types.py index 76cf78e..33cdd30 100644 --- a/src/zennit/types.py +++ b/src/zennit/types.py @@ -17,7 +17,7 @@ # along with this library. If not, see . '''Type definitions for convenience.''' import torch - +import torchvision class SubclassMeta(type): '''Meta class to bundle multiple subclasses.''' @@ -71,6 +71,7 @@ class BatchNorm(metaclass=SubclassMeta): torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.BatchNorm3d, + torchvision.ops.FrozenBatchNorm2d, )