-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_original.py
81 lines (65 loc) · 2.68 KB
/
train_original.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import pdb, os, argparse
from datetime import datetime
from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    ``type=bool`` is a classic argparse pitfall: argparse calls ``bool(text)``,
    which is True for every non-empty string, so ``--is_ResNet False`` used to
    select the ResNet backbone. This converter accepts the usual spellings and
    rejects anything else with a clear usage error.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line hyper-parameters for the training run; parsed once at import time.
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
# nargs='?' + const=True: a bare `--is_ResNet` means True, while the existing
# `--is_ResNet True` / `--is_ResNet False` spellings keep working (and now mean what they say).
parser.add_argument('--is_ResNet', type=_str2bool, nargs='?', const=True, default=False,
                    help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
opt = parser.parse_args()
print('Learning Rate: {} ResNet: {}'.format(opt.lr, opt.is_ResNet))

# build models: the ResNet variant when requested on the CLI, VGG otherwise.
model = CPD_ResNet() if opt.is_ResNet else CPD_VGG()
model.cuda()  # nn.Module.cuda() moves the parameters in place
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=opt.lr)

# NOTE(review): 'path1' / 'path2' look like placeholders for the image and
# ground-truth directories -- TODO point these at the real dataset before training.
image_root, gt_root = 'path1', 'path2'
train_loader = get_loader(image_root, gt_root,
                          batchsize=opt.batchsize,
                          trainsize=opt.trainsize)
# Used by train() for progress logging; presumably the number of batches per
# epoch if get_loader returns a DataLoader -- confirm against data.py.
total_step = len(train_loader)

# Shared loss for both model outputs. BCEWithLogitsLoss applies the sigmoid
# itself, so the model outputs are assumed to be unnormalised logits.
CE = torch.nn.BCEWithLogitsLoss()
def train(train_loader, model, optimizer, epoch):
    """Run one training epoch, then checkpoint the model on a fixed schedule.

    Besides its arguments this reads the module-level ``opt`` (parsed CLI
    options), ``CE`` (the loss) and ``total_step`` (``len(train_loader)``).

    Args:
        train_loader: iterable yielding ``(images, gts)`` batch pairs.
        model: network whose forward pass returns two outputs ``(atts, dets)``;
            each is scored against ``gts`` and the two losses are summed.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number (the driver loop counts from 1), used for logging
            and in the checkpoint file name.
    """
    model.train()
    for i, (images, gts) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # No torch.autograd.Variable wrapping: it has been a no-op since
        # PyTorch 0.4 and plain tensors carry autograd state directly.
        images = images.cuda()
        gts = gts.cuda()

        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        # clip_gradient comes from utils; opt.clip is its threshold.
        clip_gradient(optimizer, opt.clip)
        optimizer.step()

        # Log every 400 steps and on the last step of the epoch.
        if i % 400 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.item(), loss2.item()))

    save_path = 'models/CPD_Resnet/' if opt.is_ResNet else 'models/CPD_VGG/'
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)
    # NOTE(review): with epoch numbers starting at 1 this saves after epochs
    # 4, 9, 14, ... (file suffix = epoch) -- confirm this schedule is intended.
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), save_path + 'CPD.pth' + '.%d' % epoch)
print("Let's go!")
# Epochs are numbered from 1. NOTE(review): range() excludes its stop value,
# so this runs epochs 1..opt.epoch-1 -- one fewer than --epoch, even though the
# log prints "/opt.epoch". Left unchanged because train()'s checkpoint schedule
# ((epoch + 1) % 5 == 0) is keyed to these numbers; confirm before altering.
first_epoch, stop_epoch = 1, opt.epoch
for epoch in range(first_epoch, stop_epoch):
    # Set this epoch's learning rate first; presumably adjust_lr (from utils)
    # decays opt.lr by --decay_rate every --decay_epoch epochs, per the CLI help.
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)