trainer.py

import numpy as np
import torch.utils.data


class Trainer:
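    """Runs the training and validation loops for a detection-style model.

    Interface assumptions, as used by train() below: each dataset exposes a
    collate_fn, each target is a (locations, classes) tuple of tensors, and
    loss_fn returns a dict with 'total', 'classification' and 'localization'
    entries.
    """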
    def __init__(self, device: str):
        self.device = device

    def train(self, model, train_ds, val_ds, loss_fn, optimizer, train_batch_size, train_dl_workers,
              val_batch_size, val_dl_workers):
        output_interval = 100
        model.train()
        train_dl = torch.utils.data.DataLoader(train_ds, collate_fn=train_ds.collate_fn,
                                               batch_size=train_batch_size, shuffle=True,
                                               num_workers=train_dl_workers)
        val_dl = torch.utils.data.DataLoader(val_ds, collate_fn=val_ds.collate_fn,
                                             batch_size=val_batch_size, shuffle=False,
                                             num_workers=val_dl_workers)
        for epoch in range(10):  # loop over the dataset multiple times
            running_losses = np.zeros((3,))
            for i, data in enumerate(train_dl, 0):
                # get the inputs
                image_batch, y_batch = data
                local_image_batch = image_batch.to(self.device)
                local_y_batch = [(l.to(self.device), c.to(self.device)) for l, c in y_batch]

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                predicted = model(local_image_batch)
                losses = loss_fn(predicted, local_y_batch)
                loss = losses['total']
                loss.backward()
                optimizer.step()

                running_losses += np.array([loss.item(),
                                            losses['classification'].item(),
                                            losses['localization'].item()])
                # print statistics
                if i % output_interval == output_interval - 1:
                    print(f'[{epoch + 1}, {i + 1:5d}] '
                          f'loss: {running_losses[0] / output_interval:.3f}, '
                          f'class_loss: {running_losses[1] / output_interval:.3f}, '
                          f'loc_loss: {running_losses[2] / output_interval:.3f}')
                    running_losses = np.zeros((3,))

            # evaluate on the validation set with the model in eval mode and
            # gradient tracking disabled
            model.eval()
            val_dl_iterator = iter(val_dl)
            running_losses = np.zeros((3,))
            num_batches = 0
            with torch.no_grad():
                for _ in range(len(val_ds) // val_batch_size):
                    x_val, y_val = next(val_dl_iterator)
                    local_x_val = x_val.to(self.device)
                    local_y_val = [(l.to(self.device), c.to(self.device)) for l, c in y_val]
                    val_predicted = model(local_x_val)
                    val_losses = loss_fn(val_predicted, local_y_val)
                    running_losses += np.array([val_losses['total'].item(),
                                                val_losses['classification'].item(),
                                                val_losses['localization'].item()])
                    num_batches += 1
            print(f'[{epoch + 1}, {"VAL":>5}] '
                  f'val loss: {running_losses[0] / num_batches:.3f}, '
                  f'val_class_loss: {running_losses[1] / num_batches:.3f}, '
                  f'val_loc_loss: {running_losses[2] / num_batches:.3f}')
            model.train()

        print('Finished Training')
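

# Usage sketch (illustrative only, not part of the original training pipeline):
# the real model, dataset, and loss live elsewhere in the project, so the dummy
# classes below are assumptions that merely mirror the interfaces Trainer relies
# on: a dataset exposing collate_fn and yielding (image, (locations, classes))
# samples, a model taking an image batch, and a loss returning a dict with
# 'total', 'classification' and 'localization' tensors.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    class DummyDetectionDataset(torch.utils.data.Dataset):
        def __init__(self, size=64):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            image = torch.randn(3, 32, 32)
            locations = torch.randn(4)                   # one box per image
            classes = torch.randint(0, 2, (1,)).float()  # binary label
            return image, (locations, classes)

        @staticmethod
        def collate_fn(batch):
            images = torch.stack([sample[0] for sample in batch])
            targets = [sample[1] for sample in batch]    # list of (loc, cls) tuples
            return images, targets

    class DummyDetector(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 16), nn.ReLU())
            self.loc_head = nn.Linear(16, 4)
            self.cls_head = nn.Linear(16, 1)

        def forward(self, x):
            features = self.backbone(x)
            return self.loc_head(features), self.cls_head(features)

    def dummy_loss(predicted, targets):
        pred_loc, pred_cls = predicted
        target_loc = torch.stack([t[0] for t in targets])
        target_cls = torch.stack([t[1] for t in targets])
        loc_loss = nn.functional.smooth_l1_loss(pred_loc, target_loc)
        cls_loss = nn.functional.binary_cross_entropy_with_logits(pred_cls, target_cls)
        return {'total': loc_loss + cls_loss, 'classification': cls_loss, 'localization': loc_loss}

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = DummyDetector().to(device)
    trainer = Trainer(device)
    trainer.train(model, DummyDetectionDataset(64), DummyDetectionDataset(16), dummy_loss,
                  torch.optim.SGD(model.parameters(), lr=1e-2),
                  train_batch_size=8, train_dl_workers=0,
                  val_batch_size=8, val_dl_workers=0)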