-
Notifications
You must be signed in to change notification settings - Fork 0
/
non_local.py
64 lines (50 loc) · 2.15 KB
/
non_local.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch import nn
from torch.nn import functional as F
class PAM(nn.Module):
    """Group-wise correlation between two feature maps ``x_l`` and ``x_r``.

    ``forward`` splits the channels into ``num_groups`` groups and, for every
    row ``h``, correlates each column of ``x_l`` with every column of ``x_r``
    (mean of the per-channel products inside a group).

    The 1x1 projections ``g``, ``theta`` and ``phi`` built in ``__init__`` are
    not used by ``forward``; they are kept so that the module's parameters
    (and therefore saved ``state_dict`` checkpoints) stay unchanged.
    """

    def __init__(self, in_channels, inter_channels=None, sub_sample=True):
        """Create the (currently unused by ``forward``) projection layers.

        Args:
            in_channels: number of channels of the input feature maps.
            inter_channels: output channels of the 1x1 projections; defaults
                to ``in_channels // 2`` (at least 1) when ``None``.
            sub_sample: if True, ``g`` and ``phi`` are followed by a 2x2
                max-pool.
        """
        super().__init__()
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # Guard against in_channels == 1 producing a 0-channel conv.
            if self.inter_channels == 0:
                self.inter_channels = 1

        conv_nd = nn.Conv2d
        max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            # MaxPool2d has no parameters, so sharing one instance is safe.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x_l, x_r, num_groups):
        """Compute a group-wise row correlation volume between two maps.

        Args:
            x_l: tensor of shape (B, C, H, W).
            x_r: tensor with the same shape as ``x_l``.
            num_groups: number of channel groups G; must divide C.

        Returns:
            volume: tensor of shape (B, G, W, H, W) where
                ``volume[b, g, w_r, h, w_l]`` is the mean over the channels of
                group ``g`` of ``x_l[b, c, h, w_l] * x_r[b, c, h, w_r]``.
            V: bool tensor of shape (B, G, H, W); see the NOTE below.

        Raises:
            AssertionError: if C is not divisible by ``num_groups``.
        """
        B, C, H, W = x_l.shape
        assert C % num_groups == 0, (
            f"channel count {C} is not divisible by num_groups={num_groups}"
        )
        N_c = C // num_groups

        # Channel c = g * N_c + k lands in group g, matching contiguous slicing.
        left = x_l.reshape(B, num_groups, N_c, H, W)
        right = x_r.reshape(B, num_groups, N_c, H, W)

        # f[b, g, h, w_l, w_r]: true (not floor) division gives the channel
        # mean and keeps the result differentiable.
        f = torch.einsum("bgchw,bgchv->bghwv", left, right) / N_c

        # (B, G, H, W_l, W_r) -> (B, G, W_r, H, W_l).
        volume = f.permute(0, 1, 4, 2, 3).contiguous()

        M = F.softmax(f, dim=-1)
        # NOTE(review): softmax normalises over dim -1, so M.sum(4) is ~1
        # everywhere and `> 1` only reflects floating-point rounding; V is
        # effectively noise. A validity mask usually sums over the other axis
        # (dim 3) against a small threshold — confirm the intended semantics
        # before relying on V.
        V = M.sum(4) > 1
        return volume, V