diff --git a/src/zennit/types.py b/src/zennit/types.py index 76cf78e..33cdd30 100644 --- a/src/zennit/types.py +++ b/src/zennit/types.py @@ -17,7 +17,7 @@ # along with this library. If not, see . '''Type definitions for convenience.''' import torch - +import torchvision class SubclassMeta(type): '''Meta class to bundle multiple subclasses.''' @@ -71,6 +71,7 @@ class BatchNorm(metaclass=SubclassMeta): torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.BatchNorm3d, + torchvision.ops.FrozenBatchNorm2d, )