-
Notifications
You must be signed in to change notification settings - Fork 21
/
capsnet.py
171 lines (136 loc) · 6.41 KB
/
capsnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
from torch import nn
# Module-level default device: the current CUDA device when PyTorch reports
# CUDA as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def squash(x, dim=-1, eps=1e-8):
    """Apply the capsule "squashing" non-linearity along ``dim``.

    Each vector ``s`` taken along ``dim`` is mapped to
    ``|s|^2 / (1 + |s|^2) * s / |s|``. The direction is kept and the length
    is mapped into [0, 1).

    Args:
        x: Input tensor.
        dim: Dimension that holds the capsule vector components.
        eps: Small constant added to the squared norm *before* the square
            root. Keeping it inside the sqrt keeps gradients finite for
            all-zero vectors. Adding it after the sqrt (as before) gives an
            infinite d sqrt(t)/dt at t = 0, and backward produced NaN.

    Returns:
        Tensor of the same shape as ``x``.
    """
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / torch.sqrt(squared_norm + eps)
class PrimaryCaps(nn.Module):
    """Primary capsule layer.

    A single convolution produces ``num_conv_units * out_channels`` feature
    maps. Channel block ``[u * out_channels, (u + 1) * out_channels)`` belongs
    to conv unit ``u``. Every spatial position of every conv unit becomes one
    capsule, whose ``out_channels`` components are that unit's channels at
    that position.
    """

    def __init__(self, num_conv_units, in_channels, out_channels, kernel_size, stride):
        """
        Args:
            num_conv_units: Number of conv units (capsule channels).
            in_channels: Channels of the input feature map.
            out_channels: Dimensionality of each primary capsule vector.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
        """
        super(PrimaryCaps, self).__init__()
        # All conv units are computed by one convolution.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels * num_conv_units,
                              kernel_size=kernel_size,
                              stride=stride)
        self.num_conv_units = num_conv_units
        self.out_channels = out_channels

    def forward(self, x):
        """
        Args:
            x: (batch_size, in_channels, height, width)

        Returns:
            Squashed capsules of shape
            (batch_size, num_conv_units * out_height * out_width, out_channels).
        """
        # (batch_size, num_conv_units * out_channels, out_height, out_width)
        out = self.conv(x)
        batch_size, _, height, width = out.shape
        # Split channels into (unit, capsule component) and move the component
        # axis last, so each capsule vector is read from ONE spatial position.
        # A plain view(batch_size, -1, out_channels) on this NCHW tensor would
        # instead group neighbouring pixels of a single channel into a capsule.
        # NOTE: weights trained with the old grouping will behave differently.
        out = out.view(batch_size, self.num_conv_units, self.out_channels, height, width)
        out = out.permute(0, 1, 3, 4, 2).contiguous()
        # (batch_size, num_conv_units * out_height * out_width, out_channels)
        return squash(out.view(batch_size, -1, self.out_channels), dim=-1)
class DigitCaps(nn.Module):
    """Digit capsule layer with iterative routing-by-agreement."""

    def __init__(self, in_dim, in_caps, out_caps, out_dim, num_routing):
        """
        Initialize the layer.

        Args:
            in_dim: Dimensionality of each input capsule vector.
            in_caps: Number of input capsules.
            out_caps: Number of output capsules.
            out_dim: Dimensionality of each output capsule vector.
            num_routing: Number of routing iterations. Only the last one
                lets gradients flow back into the predictions.
        """
        super(DigitCaps, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        # Kept for backward compatibility. forward() allocates its routing
        # logits on the input's device rather than on this attribute.
        self.device = device
        # One (out_dim x in_dim) transformation matrix per (output, input)
        # capsule pair. The leading 1 broadcasts over the batch.
        self.W = nn.Parameter(0.01 * torch.randn(1, out_caps, in_caps, out_dim, in_dim),
                              requires_grad=True)

    def forward(self, x):
        """
        Args:
            x: Input capsules, (batch_size, in_caps, in_dim).

        Returns:
            Output capsules, (batch_size, out_caps, out_dim).
        """
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        # W @ x =
        # (1, out_caps, in_caps, out_dim, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, out_caps, in_caps, out_dim, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, out_caps, in_caps, out_dim)
        u_hat = u_hat.squeeze(-1)
        # Detached copy for the intermediate routing iterations, so no gradient
        # flows through the routing-coefficient updates.
        temp_u_hat = u_hat.detach()

        # Routing logits must live on the same device as the predictions.
        # Allocating them on the module-global `device` failed whenever the
        # model ran on a different device (e.g. CPU while CUDA is available).
        b = torch.zeros(batch_size, self.out_caps, self.in_caps, 1, device=u_hat.device)
        for _ in range(self.num_routing - 1):
            # (batch_size, out_caps, in_caps, 1) -> softmax along out_caps
            c = b.softmax(dim=1)
            # (batch_size, out_caps, in_caps, 1) * (batch_size, out_caps, in_caps, out_dim)
            # -> (batch_size, out_caps, in_caps, out_dim), summed over in_caps
            # -> (batch_size, out_caps, out_dim)
            s = (c * temp_u_hat).sum(dim=2)
            # "Squashing" non-linearity along out_dim
            v = squash(s)
            # Agreement = dot product between prediction u_hat_{j|i} and output v_j:
            # (batch_size, out_caps, in_caps, out_dim) @ (batch_size, out_caps, out_dim, 1)
            # -> (batch_size, out_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # Final iteration uses the non-detached u_hat (gradients flow here) and
        # does not update b.
        c = b.softmax(dim=1)
        s = (c * u_hat).sum(dim=2)
        # "Squashing" non-linearity along out_dim
        v = squash(s)
        return v
class CapsNet(nn.Module):
    """Basic implementation of capsule network layer.

    Layer geometry (fixed in __init__), worked out for a 1x28x28 input:
    9x9 conv with 256 channels -> 20x20; PrimaryCaps 9x9 / stride 2
    -> 32 units x 6x6 capsules of dim 8 (in_caps = 32 * 6 * 6);
    DigitCaps -> 10 capsules of dim 16; the decoder maps the masked
    16 * 10 capsule vector to 784 = 28 * 28 values.
    """

    def __init__(self):
        super(CapsNet, self).__init__()
        # Conv2d layer
        self.conv = nn.Conv2d(1, 256, 9)
        self.relu = nn.ReLU(inplace=True)

        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units=32,
                                        in_channels=256,
                                        out_channels=8,
                                        kernel_size=9,
                                        stride=2)

        # Digit capsule
        self.digit_caps = DigitCaps(in_dim=8,
                                    in_caps=32 * 6 * 6,
                                    out_caps=10,
                                    out_dim=16,
                                    num_routing=3)

        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid())

    def forward(self, x, labels=None):
        """
        Args:
            x: Images. Presumably (batch_size, 1, 28, 28), because the
                hard-coded in_caps and the 784-unit decoder only fit that size.
            labels: Optional one-hot targets, (batch_size, 10). When given,
                the decoder input is masked with them (true-class masking for
                training). When omitted, the predicted class is used, as before.

        Returns:
            (logits, reconstruction): capsule lengths (batch_size, 10) and
            decoder output (batch_size, 784).
        """
        out = self.relu(self.conv(x))
        out = self.primary_caps(out)
        # (batch_size, 10, 16)
        out = self.digit_caps(out)

        # Capsule lengths act as class scores: (batch_size, out_capsules)
        logits = torch.norm(out, dim=-1)

        if labels is None:
            # One-hot of the predicted class, built on the same device as the
            # capsules. A module-global device broke CPU inference on CUDA hosts.
            num_classes = logits.size(1)
            mask = torch.eye(num_classes, device=logits.device, dtype=out.dtype).index_select(
                dim=0, index=torch.argmax(logits, dim=1))
        else:
            mask = labels.to(device=out.device, dtype=out.dtype)

        # Reconstruction from the masked capsule vectors
        batch_size = out.shape[0]
        reconstruction = self.decoder((out * mask.unsqueeze(2)).reshape(batch_size, -1))

        return logits, reconstruction
class CapsuleLoss(nn.Module):
    """Combine margin loss & reconstruction loss of capsule network.

    total = margin_loss + reconstruction_loss_scalar * reconstruction_loss.
    Both terms are summed (not averaged) over the batch.
    """

    def __init__(self, upper_bound=0.9, lower_bound=0.1, lmda=0.5,
                 reconstruction_loss_scalar=5e-4):
        """
        Args:
            upper_bound: Length the true-class capsule should at least reach.
            lower_bound: Length absent-class capsules should stay below.
            lmda: Weight of the absent-class term of the margin loss.
            reconstruction_loss_scalar: Weight of the summed squared
                reconstruction error (defaults to the previously hard-coded 5e-4).
        """
        super(CapsuleLoss, self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = reconstruction_loss_scalar
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        """
        Args:
            images: Target images. Reconstructions are reshaped to this shape.
            labels: Targets, (batch_size, num_classes). Presumably one-hot;
                they are used as per-class weights.
            logits: Capsule lengths, (batch_size, num_classes).
            reconstructions: Decoder output with as many elements as images.

        Returns:
            Scalar loss tensor.
        """
        # Penalises a labelled class whose capsule is shorter than upper_bound
        # (false negative).
        left = (self.upper - logits).relu() ** 2
        # Penalises an unlabelled class whose capsule is longer than lower_bound
        # (false positive).
        right = (logits - self.lower).relu() ** 2
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)

        # Reconstruction loss: summed squared error
        reconstruction_loss = self.mse(reconstructions.reshape(images.shape), images)

        # Combine two losses
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss