-
Notifications
You must be signed in to change notification settings - Fork 4
/
settings.py
246 lines (206 loc) · 9.34 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import torch
from .. import utils
def get_setting(setting_name):
    """
    Maps setting names to the container classes that implement them.

    :param setting_name: Name of the setting ("GridPG", "DiFull" or "DiPart")
    :type setting_name: str
    :return: Container class for the setting (a subclass of GridContainerBase, not an instance)
    :rtype: type
    :raises KeyError: If ``setting_name`` is not a known setting; the message lists the valid names.
    """
    # Built inside the function because the container classes are defined
    # further down in this module.
    settings_dict = {"GridPG": GridPGContainer,
                     "DiFull": DiFullContainer, "DiPart": DiPartContainer}
    try:
        return settings_dict[setting_name]
    except KeyError:
        # Still a KeyError for existing callers, but with an actionable message.
        raise KeyError("Unknown setting {!r}; expected one of {}".format(
            setting_name, sorted(settings_dict))) from None


def eval_only_corners(setting_name):
    """
    Maps setting names to whether the setting evaluates only at the top-left and
    bottom-right corners, as opposed to the entire grid.

    :param setting_name: Name of the setting ("GridPG", "DiFull" or "DiPart")
    :type setting_name: str
    :return: Whether the setting evaluates only at the top-left and bottom-right corners of the grid
    :rtype: bool
    :raises KeyError: If ``setting_name`` is not a known setting; the message lists the valid names.
    """
    settings_dict = {"GridPG": False,
                     "DiFull": True, "DiPart": True}
    try:
        return settings_dict[setting_name]
    except KeyError:
        # Still a KeyError for existing callers, but with an actionable message.
        raise KeyError("Unknown setting {!r}; expected one of {}".format(
            setting_name, sorted(settings_dict))) from None


class GridContainerBase(torch.nn.Module):
    """
    Base class for the grid evaluation settings.

    Wraps a model together with the grid scale n. Subclasses must implement
    :meth:`forward` and :meth:`get_intermediate_activations`.
    """

    def __init__(self, model, scale=2):
        """
        Constructor.

        :param model: Model to evaluate on.
        :type model: ModelBase
        :param scale: Scale parameter n in the nxn grid to use for evaluation, defaults to 2
        :type scale: int, optional
        :raises AssertionError: If ``scale`` is not an integer >= 1.
        """
        super(GridContainerBase, self).__init__()
        # Type check first so that a non-integer scale fails this assertion
        # rather than raising TypeError from the ``>=`` comparison.
        assert isinstance(scale, int) and scale >= 1, \
            "scale must be an integer >= 1, got {!r}".format(scale)
        self.model = model
        self.scale = scale
        # One output head per grid cell of the nxn grid.
        self.num_heads = scale * scale

    def forward(self, x, output_head_idx=0, start_layer=None):
        """
        Runs the model on this setting and returns output logits.
        Must be overridden by subclasses; the signature matches the concrete settings.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param output_head_idx: Index of the classification head to evaluate for, defaults to 0.
        :type output_head_idx: int, optional
        :param start_layer: Convolutional layer ID to start the forward pass from, defaults to None.
        :type start_layer: int, optional
        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def get_intermediate_activations(self, x, end_layer=None):
        """
        Returns intermediate features at a specified convolutional layer.
        Must be overridden by subclasses; the signature matches the concrete settings.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param end_layer: Convolutional layer ID at which features are to be returned, defaults to None.
        :type end_layer: int, optional
        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def predict(self, x, *args, **kwargs):
        """
        Switches the module to eval mode, runs it, and returns softmax activations.

        Extra positional and keyword arguments (e.g. ``output_head_idx``,
        ``start_layer``) are forwarded to :meth:`forward`.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :return: Softmax over dim 1 of the forward output.
        :rtype: torch.Tensor
        """
        self.eval()
        return torch.nn.functional.softmax(self(x, *args, **kwargs), dim=1)


class GridPGContainer(GridContainerBase):
    """
    Evaluation on the GridPG setting.

    The whole (unsliced) input is passed through the model, which has a single
    classification head.
    """

    def __init__(self, model, scale=2):
        """
        Constructor.

        :param model: Model to evaluate on.
        :type model: ModelBase
        :param scale: scale parameter n in the nxn grid to use for evaluation, defaults to 2
        :type scale: int, optional
        """
        super().__init__(model, scale)
        # There is only one classification head in this setting.
        self.single_head = True
        # Used by VGG (per the original author's note).
        self.model.enable_classifier_kernel()

    def forward(self, x, output_head_idx=0, start_layer=None):
        """
        Runs the model on the GridPG setting and returns output logits.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param output_head_idx: Index of the classification head to evaluate for, defaults to 0.
            Ignored in the GridPG setting, since there is only one classification head.
        :type output_head_idx: int, optional
        :param start_layer: Convolutional layer ID to start the forward pass from, defaults to None.
            When None, start from the input image.
        :type start_layer: int, optional
        :return: Output logits.
        :rtype: torch.Tensor
        """
        net = self.model
        # features -> pool -> logits over the full input.
        return net.get_logits(
            net.get_pool(
                net.get_features(x, start_layer=start_layer, end_layer=None)))

    @torch.no_grad()
    def get_intermediate_activations(self, x, end_layer=None):
        """
        Returns intermediate features from the model at a specified convolutional layer,
        computed without gradient tracking.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param end_layer: Convolutional layer ID at which features are to be returned, defaults to None.
            When None, return the final feature map.
        :type end_layer: int, optional
        :return: Intermediate features.
        :rtype: torch.Tensor
        """
        net = self.model
        return net.get_features(x, start_layer=None, end_layer=end_layer)


class DiFullContainer(GridContainerBase):
    """
    Evaluation on the DiFull setting.

    Each grid cell is cropped out of the input and passed through the model
    on its own.
    """

    def __init__(self, model, scale=2):
        """
        Constructor.

        :param model: Model to evaluate on.
        :type model: ModelBase
        :param scale: scale parameter n in the nxn grid to use for evaluation, defaults to 2
        :type scale: int, optional
        """
        super().__init__(model, scale)
        self.single_head = False
        self.model.disable_classifier_kernel()

    def _crop_cell(self, x, head_idx):
        """Slice grid cell ``head_idx`` out of the last two dims of ``x``, using utils.get_augmentation_range."""
        top, left, cell_h, cell_w = utils.get_augmentation_range(
            x.shape, self.scale, head_idx)
        return x[:, :, top:top + cell_h, left:left + cell_w]

    def forward(self, x, output_head_idx=0, start_layer=None):
        """
        Runs the model on the DiFull setting and returns output logits.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param output_head_idx: Index of the classification head to evaluate for, defaults to 0.
        :type output_head_idx: int, optional
        :param start_layer: Convolutional layer ID to start the forward pass from, defaults to None.
            When None, start from the input image.
        :type start_layer: int, optional
        :return: Output logits.
        :rtype: torch.Tensor
        """
        assert 0 <= output_head_idx < self.num_heads
        # Only the requested cell is fed to the model.
        cell = self._crop_cell(x, output_head_idx)
        cell_features = self.model.get_features(
            cell, start_layer=start_layer, end_layer=None)
        return self.model.get_logits(self.model.get_pool(cell_features))

    @torch.no_grad()
    def get_intermediate_activations(self, x, end_layer=None):
        """
        Returns intermediate features from the model at a specified convolutional layer,
        computed without gradient tracking.

        Every grid cell is passed through the model separately (row-major order); the
        per-cell outputs are concatenated along dim 3 within a row and the rows along dim 2.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param end_layer: Convolutional layer ID at which features are to be returned, defaults to None.
            When None, return the final feature map.
        :type end_layer: int, optional
        :return: Intermediate features.
        :rtype: torch.Tensor
        """
        n = self.scale
        stitched_rows = [
            torch.cat(
                [self.model.get_features(self._crop_cell(x, row * n + col),
                                         start_layer=None, end_layer=end_layer)
                 for col in range(n)],
                dim=3)
            for row in range(n)
        ]
        return torch.cat(stitched_rows, dim=2)


class DiPartContainer(GridContainerBase):
    """
    Evaluation on the DiPart setting.

    The full input goes through the feature extractor; only the region of the
    resulting feature map that corresponds to the requested grid cell is pooled
    and classified.
    """

    def __init__(self, model, scale=2):
        """
        Constructor.

        :param model: Model to evaluate on.
        :type model: ModelBase
        :param scale: scale parameter n in the nxn grid to use for evaluation, defaults to 2
        :type scale: int, optional
        """
        super().__init__(model, scale)
        self.single_head = False
        self.model.disable_classifier_kernel()

    def forward(self, x, output_head_idx=0, start_layer=None):
        """
        Runs the model on the DiPart setting and returns output logits.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param output_head_idx: Index of the classification head to evaluate for, defaults to 0.
        :type output_head_idx: int, optional
        :param start_layer: Convolutional layer ID to start the forward pass from, defaults to None.
            When None, start from the input image.
        :type start_layer: int, optional
        :return: Output logits.
        :rtype: torch.Tensor
        """
        assert 0 <= output_head_idx < self.num_heads
        feature_map = self.model.get_features(
            x, start_layer=start_layer, end_layer=None)
        # The cell range is computed on the feature map's shape, not the input's.
        top, left, cell_h, cell_w = utils.get_augmentation_range(
            feature_map.shape, self.scale, output_head_idx)
        cell_features = feature_map[:, :, top:top + cell_h, left:left + cell_w]
        return self.model.get_logits(self.model.get_pool(cell_features))

    @torch.no_grad()
    def get_intermediate_activations(self, x, end_layer=None):
        """
        Returns intermediate features from the model at a specified convolutional layer,
        computed without gradient tracking on the full input.

        :param x: Input image or intermediate activations.
        :type x: torch.Tensor
        :param end_layer: Convolutional layer ID at which features are to be returned, defaults to None.
            When None, return the final feature map.
        :type end_layer: int, optional
        :return: Intermediate features.
        :rtype: torch.Tensor
        """
        net = self.model
        return net.get_features(x, start_layer=None, end_layer=end_layer)