# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# 2020
import argparse
import time
import pickle
import torch
import torchvision
import numpy as np
import os, sys
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torch.utils import data
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config
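# models.mask_net is the torchvision detection package re-implemented locally;
# datasets.dataset_segmentation, utils and config_segmentation are also local modules of this repository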

# main training routine: builds the datasets, the Mask R-CNN model and the
# optimizer, then runs the train/eval loop
def main(config, main_step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    truncation_levels = ['0', '1', '2']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    assert config.backbone_name in backbones
    assert config.mask_type in mask_classes
    assert config.truncation in truncation_levels
    # import arguments from the config file
    start_epoch, model_name, use_pretrained_resnet_backbone, num_epochs, save_dir, train_data_dir, val_data_dir, \
        imgs_dir, gt_dir, batch_size, device, save_every, lrate, rpn_nms, mask_type, backbone_name, truncation = \
        config.start_epoch, config.model_name, config.use_pretrained_resnet_backbone, config.num_epochs, \
        config.save_dir, config.train_data_dir, config.val_data_dir, config.imgs_dir, config.gt_dir, \
        config.batch_size, config.device, config.save_every, config.lrate, config.rpn_nms_th, config.mask_type, \
        config.backbone_name, config.truncation
    assert device in devices
    if save_dir not in os.listdir('.'):
        os.mkdir(save_dir)
    if batch_size > 1:
        print("The model was implemented for a batch size of one")
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)
    # load the pretrained weights if provided; in that case the ImageNet backbone weights are not needed
    if config.pretrained_model is not None:
        pretrained_model = torch.load(config.pretrained_model, map_location=device)
        use_pretrained_resnet_backbone = False
    else:
        pretrained_model = None
    torch.manual_seed(int(time.time()))
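    # note: seeding from the wall clock above makes each run non-deterministic;
    # swap in a fixed integer seed for reproducible training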
    ##############################################################################################
    # DATASETS + DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    dataset_covid_pars_train = {'stage': 'train', 'gt': os.path.join(train_data_dir, gt_dir),
                                'data': os.path.join(train_data_dir, imgs_dir),
                                'mask_type': mask_type, 'ignore_small': True}
    datapoint_covid_train = dataset.CovidCTData(**dataset_covid_pars_train)
    dataset_covid_pars_eval = {'stage': 'eval', 'gt': os.path.join(val_data_dir, gt_dir),
                               'data': os.path.join(val_data_dir, imgs_dir),
                               'mask_type': mask_type, 'ignore_small': True}
    datapoint_covid_eval = dataset.CovidCTData(**dataset_covid_pars_eval)
    ###############################################################################################
    dataloader_covid_pars_train = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train = data.DataLoader(datapoint_covid_train, **dataloader_covid_pars_train)
    dataloader_covid_pars_eval = {'shuffle': False, 'batch_size': batch_size}
    dataloader_covid_eval = data.DataLoader(datapoint_covid_eval, **dataloader_covid_pars_eval)
    ###############################################################################################
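    # each CovidCTData sample is expected to yield (image, target), where target
    # holds the 'boxes', 'labels' and 'masks' keys consumed in step() below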
    # MASK R-CNN model
    # Alex: these settings could also be added to the config
    # number of classes: 1 (background) + 2 lesion classes (GGO, C) for 'both',
    # 1 + 1 merged lesion class otherwise
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    maskrcnn_args = {'min_size': 512, 'max_size': 1024, 'rpn_batch_size_per_image': 256,
                     'rpn_positive_fraction': 0.75, 'box_positive_fraction': 0.75,
                     'box_fg_iou_thresh': 0.75, 'box_bg_iou_thresh': 0.5, 'num_classes': None,
                     'box_batch_size_per_image': 256, 'rpn_nms_thresh': rpn_nms}
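    # num_classes is left as None because custom box and mask predictors sized
    # for n_c classes are passed to the model constructor below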
    # Alex: for ground glass opacity and consolidation segmentation:
    # many small anchors, taken from all outputs of the FPN
    # IMPORTANT: for pretrained weights this determines the size of the anchor layer in the RPN,
    # so the pretrained checkpoint must store its anchor generator
    if pretrained_model is None:
        anchor_generator = AnchorGenerator(
            sizes=tuple([(2, 4, 8, 16, 32) for r in range(5)]),
            aspect_ratios=tuple([(0.1, 0.25, 0.5, 1, 1.5, 2) for rh in range(5)]))
    else:
        print("Loading the anchor generator")
        sizes = pretrained_model['anchor_generator'].sizes
        aspect_ratios = pretrained_model['anchor_generator'].aspect_ratios
        anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)
    print(anchor_generator, anchor_generator.num_anchors_per_location())
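    # with the generator above, this should report 30 anchors per spatial location
    # (5 sizes x 6 aspect ratios) on each of the 5 FPN levels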
    # in_channels: 256 is the number of channels coming out of the FPN
    # For ResNet50+FPN: keep the torchvision architecture, but with 128 box features
    # For the lightweight backbones: re-implement MaskRCNNHeads with a single layer
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
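    # 256 * 7 * 7 = 12544: the 256 FPN channels flattened over the 7x7 RoIAlign
    # window (the default pooling size in the torchvision detection design)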
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone -> FPN -> box head -> box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
    maskrcnn_args['box_head'] = box_head
    maskrcnn_args['rpn_anchor_generator'] = anchor_generator
    maskrcnn_args['mask_head'] = maskrcnn_heads
    maskrcnn_args['mask_predictor'] = mask_predictor
    maskrcnn_args['box_predictor'] = box_predictor
    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation,
                                                  pretrained_backbone=use_pretrained_resnet_backbone,
                                                  **maskrcnn_args)
    # sanity check: number of output channels of the (possibly truncated) backbone
    print(maskrcnn_model.backbone.out_channels)
    if pretrained_model is not None:
        print("Loading pretrained weights")
        maskrcnn_model.load_state_dict(pretrained_model['model_weights'])
        # resume from the epoch stored in the checkpoint
        if pretrained_model['epoch']:
            start_epoch = int(pretrained_model['epoch']) + 1
        if 'model_name' in pretrained_model.keys():
            model_name = str(pretrained_model['model_name'])
    # print the model and set it to training mode on the selected device
    print(maskrcnn_model)
    maskrcnn_model.train().to(device)
    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(maskrcnn_model.parameters(), **optimizer_pars)
    if pretrained_model is not None and 'optimizer_state' in pretrained_model.keys():
        optimizer.load_state_dict(pretrained_model['optimizer_state'])
    start_time = time.time()
    # when resuming, num_epochs counts additional epochs on top of start_epoch
    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    # anchors and save_dir are only needed at the eval stage, where checkpoints are written
    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train", e, dataloader_covid_train, optimizer, device, maskrcnn_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval, optimizer, device, maskrcnn_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print("Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(
            e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))

def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    epoch_loss = 0
    for b in dataloader:
        optimizer.zero_grad()
        X, y = b
        if device == torch.device('cuda'):
            X, y['labels'], y['boxes'], y['masks'] = \
                X.to(device), y['labels'].to(device), y['boxes'].to(device), y['masks'].to(device)
        images = [im for im in X]
        targets = []
        lab = {}
        # IMPORTANT: get rid of the first (batch) dimension; the model expects a list
        # of images and a list of per-image target dicts, not batched tensors.
        # If you have more than one image per batch, loop over the batch instead.
        lab['boxes'] = y['boxes'].squeeze_(0)
        lab['labels'] = y['labels'].squeeze_(0)
        lab['masks'] = y['masks'].squeeze_(0)
        # avoid empty objects
        if len(lab['boxes']) > 0 and len(lab['labels']) > 0 and len(lab['masks']) > 0:
            targets.append(lab)
        if len(targets) > 0:
            # in training mode the model returns a dict of losses
            loss = model(images, targets)
            total_loss = 0
            for k in loss.keys():
                total_loss += loss[k]
            # backpropagate only at the training stage; at the eval stage the same
            # loss sum serves as the validation metric
            if stage == "train":
                total_loss.backward()
                optimizer.step()
            epoch_loss += total_loss.clone().detach().cpu().numpy()
    epoch_loss = epoch_loss / len(dataloader)
    # checkpoint the model every save_every epochs, after the eval stage
    if not (e + 1) % save_every and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_name': model_name, 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors}
        if model_name is None:
            print(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e + 1) + ".pth")
            torch.save(state, os.path.join(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e + 1) + ".pth"))
        else:
            torch.save(state, os.path.join(save_dir, model_name + "_ckpt_" + str(e + 1) + ".pth"))
        model.train()
    return epoch_loss

# run the training of the segmentation algorithm
if __name__ == '__main__':
    config_train = config.get_config_pars("trainval")
    main(config_train, step)
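# A checkpoint saved in step() above can be inspected or restored later, e.g.
# (a minimal sketch; the path is hypothetical):
#   ckpt = torch.load("saved_models/mrcnn_covid_segmentation_model_ckpt_50.pth", map_location="cpu")
#   ckpt.keys()  # epoch, model_name, model_weights, optimizer_state, lrate, anchor_generator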