-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_student.py
712 lines (606 loc) · 27.2 KB
/
train_student.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
# Copyright (C) 2024. All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import os
import argparse
import sys
import time
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from models import model_dict
from models.util import ConvReg, LinearEmbed, Connector, Translator, Paraphraser
from datasets import get_cifar100_dataloaders, get_cifar100_dataloaders_sample
from datasets import get_cifar10_dataloaders, get_cifar10_dataloaders_sample
from distillers import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss, CRDLoss
from distillers import RRDLoss, DCDLoss
def parse_option():
    """Parse command-line options for student training and derive run paths.

    Besides the raw CLI arguments, the returned namespace carries:
    ``lr_decay_epochs`` as a list of ints, ``model_t`` (teacher architecture
    inferred from ``--path_t``), ``model_name`` (run identifier), and
    ``tb_folder`` / ``save_folder`` (both created if missing).

    Exits through ``parser.error`` when ``--path_t``, ``--alpha`` or
    ``--beta`` is missing: they default to ``None`` but the run cannot
    proceed without them.
    """
    parser = argparse.ArgumentParser('PyTorch Knowledge Distillation - Student training')
    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=40, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
    parser.add_argument('--init_epochs', type=int, default=30, help='init training for two-stage methods')
    # Optimization
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'cifar10'], help='dataset')
    # Model
    parser.add_argument('--model_s', type=str, default='resnet8', choices=['resnet8', 'resnet14', 'resnet20', 'resnet32',
                                                                         'resnet44', 'resnet56', 'resnet110', 'resnet8x4',
                                                                         'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1',
                                                                         'wrn_40_2', 'vgg8', 'vgg11', 'vgg13', 'vgg16',
                                                                         'vgg19', 'ResNet50', 'MobileNetV2', 'ShuffleV1', 'ShuffleV2'])
    parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot')
    # Distillation
    parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'hint', 'attention', 'similarity',
                                                                      'correlation', 'vid', 'crd', 'kdsvd',
                                                                      'fsp', 'rkd', 'pkt', 'abound', 'factor', 'nst',
                                                                      'dcd', 'rrd'])
    parser.add_argument('--trial', type=str, default='1', help='trial id')
    parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
    parser.add_argument('-a', '--alpha', type=float, default=None, help='weight balance for KD')
    parser.add_argument('-b', '--beta', type=float, default=None, help='weight balance for other losses')
    # KL distillation
    parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation')
    # NCE distillation
    parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
    parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
    parser.add_argument('--nce_t_s', default=0.04, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_t_t', default=0.10, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')
    parser.add_argument('--mode', default='exact', type=str, choices=['exact', 'relax'])
    # Other
    parser.add_argument('--hint_layer', default=2, type=int, choices=[0, 1, 2, 3, 4])
    opt = parser.parse_args()

    # These options default to None, but the run cannot work without them:
    # the teacher name is parsed from --path_t, and train() multiplies the
    # loss terms by alpha and beta. Fail here with a clear message instead of
    # an AttributeError/TypeError deep inside training.
    if opt.path_t is None:
        parser.error('--path_t (teacher checkpoint) is required')
    if opt.alpha is None or opt.beta is None:
        parser.error('both -a/--alpha and -b/--beta must be specified')

    # The lightweight students below use a smaller learning rate; this
    # overrides any --learning_rate given on the command line.
    if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
        opt.learning_rate = 0.01

    # Output locations
    opt.model_path = './save/student/student_model'
    opt.tb_path = './save/student/student_tensorboards'

    # "150,180,210" -> [150, 180, 210]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_t = get_teacher_name(opt.path_t)
    opt.model_name = '{}_{}_S_{}_T_{}_r_{}_a_{}_b_{}_trial_{}'.format(
        opt.distill.upper(), opt.dataset, opt.model_s, opt.model_t,
        opt.gamma, opt.alpha, opt.beta, opt.trial)

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    os.makedirs(opt.tb_folder, exist_ok=True)
    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    os.makedirs(opt.save_folder, exist_ok=True)
    return opt
def get_teacher_name(model_path):
    """Infer the teacher architecture name from its checkpoint path.

    The name is taken from the checkpoint's parent directory, up to the first
    underscore. WRN names keep three underscore-separated segments because
    the architecture name itself contains underscores. Examples:
    ``.../resnet32x4_vanilla/ckpt.pth`` -> ``resnet32x4`` and
    ``.../wrn_40_2_vanilla/ckpt.pth`` -> ``wrn_40_2``.

    Args:
        model_path: path to the teacher checkpoint file.

    Returns:
        The architecture name as a string.

    Raises:
        ValueError: if ``model_path`` is None or has no parent directory.
    """
    if model_path is None:
        raise ValueError('a teacher checkpoint path is required to infer the teacher name')
    # os.path (instead of split('/')) also understands the native separator,
    # e.g. backslashes on Windows.
    parent_dir = os.path.basename(os.path.dirname(model_path))
    if not parent_dir:
        raise ValueError('cannot infer teacher name from path without a parent directory: {!r}'.format(model_path))
    segments = parent_dir.split('_')
    if segments[0] != 'wrn':
        return segments[0]
    return '_'.join(segments[:3])
def load_teacher(model_path, n_cls):
    """Build the teacher network and load its weights from a checkpoint.

    Args:
        model_path: checkpoint path; its parent directory names the
            architecture (see ``get_teacher_name``). The checkpoint must hold
            the state dict under the ``'model'`` key.
        n_cls: number of output classes.

    Returns:
        The teacher model with weights loaded, on the CPU.
    """
    print('==> Loading teacher model')
    model_t = get_teacher_name(model_path)
    model = model_dict[model_t](num_classes=n_cls)
    # Always deserialize onto the CPU. The freshly built model lives on the
    # CPU, so load_state_dict copies the weights there either way, and main()
    # moves it to the GPU later. This replaces a bare ``except:`` retry that
    # also caught KeyboardInterrupt/SystemExit and retried on any error.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    print('Teacher model loaded')
    return model
def main():
    """Distill a pre-trained teacher into a student network.

    Steps:
    - Build the CIFAR data loaders and load the teacher from ``--path_t``.
    - Create the student plus the auxiliary modules/criteria required by
      ``--distill``. For abound/factor/fsp this also runs the extra ``init``
      stage.
    - Train for ``opt.epochs`` epochs.
    - Save the best, periodic and final student checkpoints.

    TensorBoard scalars are written to ``opt.tb_folder``.
    """
    best_acc = 0
    opt = parse_option()

    # ----------------------------- Tensorboard -----------------------------
    logger = SummaryWriter(log_dir=opt.tb_folder)
    # Close the writer even if training fails, so buffered events are flushed.
    try:
        # ----------------------------- Data loader -----------------------------
        # CRD uses the "sample" loaders, whose batches carry an extra
        # contrast_idx; all other methods use the instance loaders.
        if opt.dataset == 'cifar100':
            if opt.distill in ['crd']:
                train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                    batch_size=opt.batch_size, num_workers=opt.num_workers,
                    k=opt.nce_k, mode=opt.mode)
            else:
                train_loader, val_loader, n_data = get_cifar100_dataloaders(
                    batch_size=opt.batch_size, num_workers=opt.num_workers,
                    is_instance=True)
            n_cls = 100
        elif opt.dataset == 'cifar10':
            if opt.distill in ['crd']:
                train_loader, val_loader, n_data = get_cifar10_dataloaders_sample(
                    batch_size=opt.batch_size, num_workers=opt.num_workers,
                    k=opt.nce_k, mode=opt.mode)
            else:
                train_loader, val_loader, n_data = get_cifar10_dataloaders(
                    batch_size=opt.batch_size, num_workers=opt.num_workers,
                    is_instance=True)
            n_cls = 10
        else:
            raise NotImplementedError(opt.dataset)

        # -------------------------------- Model --------------------------------
        model_t = load_teacher(opt.path_t, n_cls)
        model_s = model_dict[opt.model_s](num_classes=n_cls)

        # ------------------------------ Mock data ------------------------------
        # A dummy forward pass reveals the feature shapes used to size the
        # auxiliary modules below. Only ``.shape`` is read, so no autograd
        # graph is needed.
        data = torch.randn(2, 3, 32, 32)
        model_t.eval()
        model_s.eval()
        with torch.no_grad():
            feat_t, _ = model_t(data, is_feat=True)
            feat_s, _ = model_s(data, is_feat=True)

        # ------------------------------- Modules -------------------------------
        # module_list: modules that train() switches between train/eval mode
        #   and that get moved to the GPU. The student is element 0 and the
        #   teacher is appended last.
        # trainable_list: modules whose parameters go to the optimizer.
        module_list = nn.ModuleList([])
        module_list.append(model_s)
        trainable_list = nn.ModuleList([])
        trainable_list.append(model_s)

        # ------------------------------- Criteria ------------------------------
        criterion_cls = nn.CrossEntropyLoss()
        criterion_div = DistillKL(opt.kd_T)
        if opt.distill == 'kd':
            criterion_kd = DistillKL(opt.kd_T)
        elif opt.distill == 'hint':
            criterion_kd = HintLoss()
            regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
            module_list.append(regress_s)
            trainable_list.append(regress_s)
        elif opt.distill == 'crd':
            opt.s_dim = feat_s[-1].shape[1]
            opt.t_dim = feat_t[-1].shape[1]
            opt.n_data = n_data
            criterion_kd = CRDLoss(opt)
            module_list.append(criterion_kd.embed_s)
            module_list.append(criterion_kd.embed_t)
            trainable_list.append(criterion_kd.embed_s)
            trainable_list.append(criterion_kd.embed_t)
        elif opt.distill == 'attention':
            criterion_kd = Attention()
        elif opt.distill == 'nst':
            criterion_kd = NSTLoss()
        elif opt.distill == 'similarity':
            criterion_kd = Similarity()
        elif opt.distill == 'rkd':
            criterion_kd = RKDLoss()
        elif opt.distill == 'pkt':
            criterion_kd = PKT()
        elif opt.distill == 'kdsvd':
            criterion_kd = KDSVD()
        elif opt.distill == 'correlation':
            criterion_kd = Correlation()
            embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
            embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
            module_list.append(embed_s)
            module_list.append(embed_t)
            trainable_list.append(embed_s)
            trainable_list.append(embed_t)
        elif opt.distill == 'vid':
            s_n = [f.shape[1] for f in feat_s[1:-1]]
            t_n = [f.shape[1] for f in feat_t[1:-1]]
            criterion_kd = nn.ModuleList([VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
            # The VID criteria are optimized jointly with the student.
            trainable_list.append(criterion_kd)
        elif opt.distill == 'abound':
            s_shapes = [f.shape for f in feat_s[1:-1]]
            t_shapes = [f.shape for f in feat_t[1:-1]]
            connector = Connector(s_shapes, t_shapes)
            # The connector and the student's feature modules are trained in
            # the init stage only. Afterwards the connector sits in
            # module_list (index 1), is not optimized, and train() keeps it in
            # eval mode.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(connector)
            init_trainable_list.append(model_s.get_feat_modules())
            criterion_kd = ABLoss(len(feat_s[1:-1]))
            init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
            module_list.append(connector)
        elif opt.distill == 'factor':
            s_shape = feat_s[-2].shape
            t_shape = feat_t[-2].shape
            paraphraser = Paraphraser(t_shape)
            translator = Translator(s_shape, t_shape)
            # The paraphraser is trained with an MSE objective in the init
            # stage. Afterwards it sits in module_list (index 2) and train()
            # keeps it in eval mode; only the translator is optimized.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(paraphraser)
            criterion_init = nn.MSELoss()
            init(model_s, model_t, init_trainable_list, criterion_init, train_loader, logger, opt)
            criterion_kd = FactorTransfer()
            module_list.append(translator)
            module_list.append(paraphraser)
            trainable_list.append(translator)
        elif opt.distill == 'fsp':
            s_shapes = [s.shape for s in feat_s[:-1]]
            t_shapes = [t.shape for t in feat_t[:-1]]
            criterion_kd = FSP(s_shapes, t_shapes)
            # The FSP transfer happens only in the init stage, over the
            # student's feature modules.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(model_s.get_feat_modules())
            init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
        elif opt.distill == 'rrd':
            opt.s_dim = feat_s[-1].shape[1]
            opt.t_dim = feat_t[-1].shape[1]
            criterion_kd = RRDLoss(opt)
            module_list.append(criterion_kd.embed_s)
            module_list.append(criterion_kd.embed_t)
            trainable_list.append(criterion_kd.embed_s)
            trainable_list.append(criterion_kd.embed_t)
        elif opt.distill == 'dcd':
            opt.s_dim = feat_s[-1].shape[1]
            opt.t_dim = feat_t[-1].shape[1]
            criterion_kd = DCDLoss(opt)
            module_list.append(criterion_kd.embed_s)
            module_list.append(criterion_kd.embed_t)
            module_list.append(criterion_kd.params)
            trainable_list.append(criterion_kd.embed_s)
            trainable_list.append(criterion_kd.embed_t)
            trainable_list.append(criterion_kd.params)
        else:
            raise NotImplementedError(opt.distill)

        # train() relies on this order: [classification, KL, method-specific].
        criterion_list = nn.ModuleList([])
        criterion_list.append(criterion_cls)  # Classification loss
        criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
        criterion_list.append(criterion_kd)   # Other knowledge distillation loss

        # ------------------------------ Optimizer ------------------------------
        optimizer = optim.SGD(trainable_list.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum, weight_decay=opt.weight_decay)
        # The teacher joins module_list only after the optimizer is built, so
        # its parameters are never optimized. It must be last, because
        # train() treats module_list[-1] as the teacher.
        module_list.append(model_t)

        # ------------------------------- To CUDA -------------------------------
        if torch.cuda.is_available():
            module_list.cuda()
            criterion_list.cuda()
            cudnn.benchmark = True

        # ----------------------------- Eval teacher ----------------------------
        teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
        # float() accepts both the tensor average and the plain int 0 that
        # validate() returns for an empty loader.
        print(f'Teacher accuracy: {float(teacher_acc)}%')

        # ------------------------------- Routine -------------------------------
        for epoch in range(1, opt.epochs + 1):
            adjust_learning_rate(epoch, opt, optimizer)
            print("==> Training...")

            time1 = time.time()
            train_acc, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
            time2 = time.time()
            print('Epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
            logger.add_scalar('train_acc', train_acc, epoch)
            logger.add_scalar('train_loss', train_loss, epoch)

            test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)
            logger.add_scalar('test_acc', test_acc, epoch)
            logger.add_scalar('test_loss', test_loss, epoch)
            logger.add_scalar('test_acc_top5', test_acc_top5, epoch)

            # Save the best model
            if test_acc > best_acc:
                best_acc = test_acc
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'best_acc': best_acc,
                }
                save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
                print('Saving the best model!')
                torch.save(state, save_file)

            # Regular saving
            if epoch % opt.save_freq == 0:
                print('==> Saving...')
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'accuracy': test_acc,
                }
                save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
                torch.save(state, save_file)

        # This best accuracy is only for printing purpose.
        # The results reported in the paper/README is from the last epoch.
        print('==> Best student accuracy:', best_acc)

        # Save the final model
        state = {
            'opt': opt,
            'model': model_s.state_dict(),
        }
        save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
        torch.save(state, save_file)
    finally:
        logger.close()
def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
    """Pre-train auxiliary modules for the two-stage methods (abound/factor/fsp).

    Runs ``opt.init_epochs`` epochs of SGD over ``init_modules`` only.
    Student and teacher are put in eval mode, and the teacher forward runs
    without gradients. The per-epoch mean loss is logged as
    ``init_train_loss``.

    Args:
        model_s: student network.
        model_t: teacher network.
        init_modules: ``nn.ModuleList`` of the modules optimized here; for
            abound/factor, element 0 is the connector/paraphraser.
        criterion: loss for this stage.
        train_loader: training data loader.
        logger: TensorBoard writer.
        opt: parsed options.

    Raises:
        NotImplementedError: if ``opt.distill`` has no init stage.
    """
    model_t.eval()
    model_s.eval()
    init_modules.train()

    if torch.cuda.is_available():
        model_s.cuda()
        model_t.cuda()
        init_modules.cuda()
        cudnn.benchmark = True

    # Factor transfer with a ResNet/WRN student uses a fixed, smaller LR.
    small_lr_students = ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                         'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2']
    use_small_lr = opt.distill == 'factor' and opt.model_s in small_lr_students
    lr = 0.01 if use_small_lr else opt.learning_rate
    optimizer = optim.SGD(init_modules.parameters(), lr=lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    has_contrast = opt.distill in ['crd']
    # Only ABound works on pre-activation features.
    preact = (opt.distill == 'abound')

    for epoch in range(1, opt.init_epochs + 1):
        for meter in (batch_time, data_time, losses):
            meter.reset()
        tic = time.time()
        for step, batch in enumerate(train_loader):
            if has_contrast:
                images, labels, index, contrast_idx = batch
            else:
                images, labels, index = batch
            data_time.update(time.time() - tic)

            images = images.float()
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()
                index = index.cuda()
                if has_contrast:
                    contrast_idx = contrast_idx.cuda()

            # ==============Forward===============
            feat_s, _ = model_s(images, is_feat=True, preact=preact)
            with torch.no_grad():
                feat_t, _ = model_t(images, is_feat=True, preact=preact)
                feat_t = [f.detach() for f in feat_t]

            if opt.distill == 'abound':
                connected = init_modules[0](feat_s[1:-1])
                loss = sum(criterion(connected, feat_t[1:-1]))
            elif opt.distill == 'factor':
                target_feat = feat_t[-2]
                _, reconstruction = init_modules[0](target_feat)
                loss = criterion(reconstruction, target_feat)
            elif opt.distill == 'fsp':
                loss = sum(criterion(feat_s[:-1], feat_t[:-1]))
            else:
                raise NotImplementedError('Not supported in init training: {}'.format(opt.distill))

            losses.update(loss.item(), images.size(0))

            # ===================Backward=====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - tic)
            tic = time.time()

        # End of epoch: log and print the epoch statistics.
        logger.add_scalar('init_train_loss', losses.avg, epoch)
        print('Epoch: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'losses: {losses.val:.3f} ({losses.avg:.3f})'.format(
                  epoch, opt.init_epochs, batch_time=batch_time, losses=losses))
        sys.stdout.flush()
def _distill_term(opt, criterion_kd, module_list, feat_s, feat_t, index, contrast_idx):
    """Return the method-specific loss term added on top of the CE and KL terms.

    Auxiliary modules come from ``module_list`` (index 1 and/or 2, depending
    on the method). Returns the int 0 for kd, abound and fsp, which add no
    term at this stage.

    Raises:
        NotImplementedError: for an unknown ``opt.distill``.
    """
    method = opt.distill
    if method in ('kd', 'abound', 'fsp'):
        return 0
    if method == 'hint':
        f_s = module_list[1](feat_s[opt.hint_layer])
        f_t = feat_t[opt.hint_layer]
        return criterion_kd(f_s, f_t)
    if method == 'crd':
        return criterion_kd(feat_s[-1], feat_t[-1], index, contrast_idx)
    if method in ('attention', 'nst', 'kdsvd'):
        # Per-layer losses over every feature except the first and the last.
        return sum(criterion_kd(feat_s[1:-1], feat_t[1:-1]))
    if method == 'similarity':
        return sum(criterion_kd([feat_s[-2]], [feat_t[-2]]))
    if method in ('rkd', 'pkt', 'rrd', 'dcd'):
        return criterion_kd(feat_s[-1], feat_t[-1])
    if method == 'correlation':
        f_s = module_list[1](feat_s[-1])
        f_t = module_list[2](feat_t[-1])
        return criterion_kd(f_s, f_t)
    if method == 'vid':
        # criterion_kd is a ModuleList with one VID loss per feature pair.
        return sum([c(f_s, f_t) for f_s, f_t, c in zip(feat_s[1:-1], feat_t[1:-1], criterion_kd)])
    if method == 'factor':
        factor_s = module_list[1](feat_s[-2])
        factor_t = module_list[2](feat_t[-2], is_factor=True)
        return criterion_kd(factor_s, factor_t)
    raise NotImplementedError(opt.distill)


def train(epoch, train_loader, module_list, criterion_list, optimizer, opt):
    """Run one epoch of distillation and return (top-1 accuracy, mean loss).

    ``module_list[0]`` is the student and ``module_list[-1]`` is the teacher.
    Every module is switched to train mode, except:
    - the teacher;
    - ``module_list[1]`` for abound;
    - ``module_list[2]`` for factor.

    The batch objective is
    ``gamma * cls + alpha * KL + beta * method-specific term``,
    where the three criteria are ``criterion_list[0..2]``.
    """
    for module in module_list:
        module.train()
    module_list[-1].eval()
    frozen_idx = {'abound': 1, 'factor': 2}.get(opt.distill)
    if frozen_idx is not None:
        module_list[frozen_idx].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]
    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time, data_time, losses, top1, top5 = (AverageMeter() for _ in range(5))

    has_contrast = opt.distill in ['crd']
    use_preact = opt.distill in ['abound']

    tic = time.time()
    for step, batch in enumerate(train_loader):
        contrast_idx = None
        if has_contrast:
            images, labels, index, contrast_idx = batch
        else:
            images, labels, index = batch
        data_time.update(time.time() - tic)

        images = images.float()
        if torch.cuda.is_available():
            images = images.cuda()
            labels = labels.cuda()
            index = index.cuda()
            if has_contrast:
                contrast_idx = contrast_idx.cuda()

        # ===================Forward=====================
        feat_s, logit_s = model_s(images, is_feat=True, preact=use_preact)
        with torch.no_grad():
            feat_t, logit_t = model_t(images, is_feat=True, preact=use_preact)
            feat_t = [f.detach() for f in feat_t]

        loss_cls = criterion_cls(logit_s, labels)
        loss_div = criterion_div(logit_s, logit_t)
        loss_kd = _distill_term(opt, criterion_kd, module_list, feat_s, feat_t, index, contrast_idx)
        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd

        acc1, acc5 = accuracy(logit_s, labels, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(acc1[0], n)
        top5.update(acc5[0], n)

        # ===================Backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, step, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg, losses.avg
def validate(val_loader, model, criterion, opt):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Switches ``model`` to eval mode and leaves it there. Prints progress
    every ``opt.print_freq`` batches, then a summary line.

    Returns:
        ``(top1_avg, top5_avg, loss_avg)``: sample-weighted averages over
        the whole loader.
    """
    time_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        tic = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.float()
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()

            logits = model(images)
            loss = criterion(logits, labels)

            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            n = images.size(0)
            loss_meter.update(loss.item(), n)
            acc1_meter.update(acc1[0], n)
            acc5_meter.update(acc5[0], n)

            time_meter.update(time.time() - tic)
            tic = time.time()

            if step % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          step, len(val_loader), batch_time=time_meter, loss=loss_meter,
                          top1=acc1_meter, top5=acc5_meter))

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=acc1_meter, top5=acc5_meter))

    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
def adjust_learning_rate_new(epoch, optimizer, LUT):
    """Set a piecewise-constant learning rate from a lookup table (RotNet style).

    ``LUT`` is a sequence of ``(max_epoch, lr)`` pairs. The rate of the first
    pair with ``max_epoch > epoch`` is used; if none matches, the last
    pair's rate is used. An empty ``LUT`` raises IndexError.
    """
    chosen = LUT[-1][1]
    for max_epoch, candidate in LUT:
        if max_epoch > epoch:
            chosen = candidate
            break
    for group in optimizer.param_groups:
        group['lr'] = chosen
def adjust_learning_rate(epoch, opt, optimizer):
    """Apply step decay: ``learning_rate * lr_decay_rate ** k``.

    ``k`` is the number of milestones in ``opt.lr_decay_epochs`` that are
    strictly smaller than ``epoch``. When ``k`` is 0 the optimizer is left
    untouched, so it keeps the learning rate it was built with.
    """
    milestones = np.asarray(opt.lr_decay_epochs)
    num_decays = np.sum(epoch > milestones)
    if num_decays <= 0:
        return
    decayed_lr = opt.learning_rate * (opt.lr_decay_rate ** num_decays)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
class AverageMeter(object):
    """Track the latest value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor with classes along dim 1 (ranked with
            ``topk(..., dim=1)``).
        target: class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes. A k larger than
        the number of classes covers every class, i.e. 100%.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # Clamp so that e.g. topk=(1, 5) does not crash on outputs with fewer
        # than 5 classes (Tensor.topk rejects k > size of the dimension).
        maxk = min(max(topk), output.size(1))
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row i holds every sample's (i+1)-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk simply takes all rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()