-
Notifications
You must be signed in to change notification settings - Fork 0
/
descriptor.py
61 lines (58 loc) · 1.87 KB
/
descriptor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import division, print_function
import sys
from copy import deepcopy
import math
import argparse
import torch
import torch.nn.init
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import os
from tqdm import tqdm
import numpy as np
import random
import cv2
import copy
import PIL
import torch.nn as nn
import torch.nn.functional as F
from Utils import L2Norm, cv2_scale, np_reshape
class DesNet(nn.Module):
    """DesNet model definition.

    A small convolutional network that turns a batch of single-channel
    images into one flattened feature vector per sample.  The vector is
    passed through ``L2Norm`` (imported from ``Utils``; presumably an
    L2 normalisation layer -- confirm against that module).
    """

    def __init__(self):
        super(DesNet, self).__init__()
        # Two stride-2 3x3 convolutions followed by an 8x8 convolution
        # without padding.  NOTE(review): with a 32x32 input the 8x8
        # kernel collapses the map to 1x1, i.e. 128 values per sample --
        # confirm the intended patch size with the caller.
        layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
            nn.Conv2d(32, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128, affine=False),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(128, 128, kernel_size=8, bias=False),
            nn.BatchNorm2d(128, affine=False),
        ]
        self.features = nn.Sequential(*layers)
        # Custom initialisation of every sub-module via ``weights_init``.
        self.features.apply(weights_init)

    def input_norm(self, x):
        """Standardise each sample of ``x`` by its own mean and std.

        Statistics are taken over all non-batch elements of a sample and
        are detached, so no gradient flows through them.  ``x`` must be
        4-dimensional for ``expand_as`` to succeed.
        """
        per_sample = x.view(x.size(0), -1)
        mean = torch.mean(per_sample, dim=1)
        # Small epsilon keeps the division finite for constant samples.
        std = torch.std(per_sample, dim=1) + 1e-7
        mean_map = mean.detach()[:, None, None, None].expand_as(x)
        std_map = std.detach()[:, None, None, None].expand_as(x)
        return (x - mean_map) / std_map

    def forward(self, input):
        """Return the ``L2Norm``-processed flattened features of ``input``."""
        normalised_input = self.input_norm(input)
        feature_maps = self.features(normalised_input)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        l2_layer = L2Norm()
        return l2_layer(flattened)
def weights_init(m):
    """Initialise a ``nn.Conv2d`` module in place; ignore any other module.

    Intended for use with ``nn.Module.apply``.  Convolution weights get
    an orthogonal initialisation with gain 0.6.  If the layer has a bias
    term it is filled with 0.01; layers created with ``bias=False`` have
    ``m.bias is None`` and are left without one.

    Args:
        m: The module visited by ``apply``.

    Returns:
        None.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.orthogonal_(m.weight.data, gain=0.6)
    # Explicit check instead of a bare ``except``: only the "no bias"
    # case is expected, and any real failure should surface.
    if m.bias is not None:
        # ``constant_`` is the in-place successor of the deprecated
        # ``nn.init.constant``.
        nn.init.constant_(m.bias.data, 0.01)
    return