-
Notifications
You must be signed in to change notification settings - Fork 14
/
train.py
118 lines (105 loc) · 5.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import logging
import os
import random
import importlib
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from trainer import trainer_synapse
from config import get_config
from model.swin_deeplab import SwinDeepLab
# Command-line interface of the training script. Every option declared here
# becomes an attribute of the parsed namespace, which is later handed as a
# whole to `get_config` and to the dataset trainer.
parser = argparse.ArgumentParser()

# --- Data and experiment layout ---
parser.add_argument('--root_path', type=str,
                    default='../data/Synapse/train_npz', help='root dir for data')
parser.add_argument('--config_file', type=str,
                    default='swin_224_7_a', help='config file name w/o suffix')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--output_dir', type=str, default='.', help='output dir')

# --- Optimisation schedule ---
# NOTE: the help text of --max_iterations used to repeat the --max_epochs
# wording; it describes an iteration budget, not an epoch budget.
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')

# --- Options not read directly in this script; they travel on `args` ---
# parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                         'full: cache all data, '
                         'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
# Parse the command line once at import time; `args` and `config` are
# module-level names used by the entry block below.
args = parser.parse_args()
if args.dataset == "Synapse":
    # Synapse data is read from a `train_npz` sub-directory of the root.
    # Append it only when the path does not already end there: the default
    # --root_path already includes `train_npz`, and joining unconditionally
    # would resolve to a doubled `.../train_npz/train_npz`.
    if os.path.basename(os.path.normpath(args.root_path)) != "train_npz":
        args.root_path = os.path.join(args.root_path, "train_npz")
# Build the run configuration from the (possibly adjusted) arguments.
config = get_config(args)
if __name__ == "__main__":
    # Script entry point: configure reproducibility, resolve dataset settings,
    # build the SwinDeepLab network, optionally load pretrained Swin weights
    # into its parts, and hand everything to the dataset-specific trainer.

    # cuDNN flags: benchmark mode when determinism is not requested,
    # deterministic kernels otherwise.
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True

    # Seed every random number generator used by the script.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'list_dir': './lists/lists_Synapse',
            'num_classes': 9,
        },
    }

    # Scale the base learning rate linearly with the batch size, relative to
    # the reference batch of 24; only applied when the batch size is a
    # multiple of 6.
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24

    # The per-dataset table is authoritative: these assignments overwrite any
    # --num_classes / --list_dir value given on the command line.
    # An unknown --dataset raises KeyError here.
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.list_dir = dataset_config[dataset_name]['list_dir']

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)

    # The model architecture is described by a module under model/configs,
    # selected by name via --config_file.
    model_config = importlib.import_module(f'model.configs.{args.config_file}')

    net = SwinDeepLab(
        model_config.EncoderConfig,
        model_config.ASPPConfig,
        model_config.DecoderConfig
    ).cuda()

    # One checkpoint file feeds every Swin-based part that asks for
    # pretrained weights.
    pretrained_ckpt = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'

    encoder_cfg = model_config.EncoderConfig
    if encoder_cfg.encoder_name == 'swin' and encoder_cfg.load_pretrained:
        net.encoder.load_from(pretrained_ckpt)

    aspp_cfg = model_config.ASPPConfig
    if aspp_cfg.aspp_name == 'swin' and aspp_cfg.load_pretrained:
        net.aspp.load_from(pretrained_ckpt)

    decoder_cfg = model_config.DecoderConfig
    if decoder_cfg.decoder_name == 'swin' and decoder_cfg.load_pretrained:
        # `extended_load` selects between the two decoder loading routines.
        if decoder_cfg.extended_load:
            net.decoder.load_from_extended(pretrained_ckpt)
        else:
            net.decoder.load_from(pretrained_ckpt)

    # Dispatch to the training routine registered for the chosen dataset.
    trainer = {'Synapse': trainer_synapse,}
    trainer[dataset_name](args, net, args.output_dir)