Skip to content

Commit

Permalink
Add torchvision.ops.FrozenBatchNorm2d to types
Browse files Browse the repository at this point in the history
  • Loading branch information
rachtibat authored Oct 4, 2023
1 parent 60a2c08 commit fc915a0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/zennit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# along with this library. If not, see <https://www.gnu.org/licenses/>.
'''Type definitions for convenience.'''
import torch

import torchvision

class SubclassMeta(type):
'''Meta class to bundle multiple subclasses.'''
Expand Down Expand Up @@ -71,6 +71,7 @@ class BatchNorm(metaclass=SubclassMeta):
torch.nn.modules.batchnorm.BatchNorm1d,
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.BatchNorm3d,
torchvision.ops.FrozenBatchNorm2d,
)


Expand Down

0 comments on commit fc915a0

Please sign in to comment.