-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
58 lines (46 loc) · 1.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import logging
import warnings
import torch
import random
import numpy as np
import torch.backends.cudnn
import wandb
from parse_args import parse_arguments
from train_classification import Experiment
from train_segmentation import Experiment as SegmentationExperiment
from globals import CONFIG
def main():
    """Run the experiment selected by ``CONFIG.task``.

    Unless ``CONFIG.use_nondeterministic`` is set, the Python, NumPy and
    PyTorch random generators are seeded with ``CONFIG.seed``, cuDNN
    autotuning is turned off, and PyTorch is asked to use deterministic
    algorithms (``warn_only=True``: operations without a deterministic
    implementation emit a warning instead of raising).

    A ``SegmentationExperiment`` is created when ``CONFIG.task`` equals
    ``'segmentation'``; any other value creates the ``Experiment`` imported
    from ``train_classification``. The experiment's ``fit`` method is then
    called with ``save_checkpoint=CONFIG.save_checkpoint``.
    """
    # Select whether to use deterministic behavior
    if not CONFIG.use_nondeterministic:
        torch.manual_seed(CONFIG.seed)
        random.seed(CONFIG.seed)
        np.random.seed(CONFIG.seed)
        # cuDNN benchmarking may pick a different convolution algorithm on
        # each run, which defeats reproducibility, so keep it disabled in
        # the deterministic path.
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(mode=True, warn_only=True)

    if CONFIG.task == 'segmentation':
        experiment = SegmentationExperiment()
    else:
        experiment = Experiment()

    experiment.fit(save_checkpoint=CONFIG.save_checkpoint)
if __name__ == '__main__':
    # Script entry point: populate CONFIG from the command line, prepare the
    # output directory, file logging and optional W&B tracking, then hand
    # over to main().
    warnings.filterwarnings(action='ignore', category=UserWarning)

    # Overlay the parsed command-line options onto the shared CONFIG object.
    CONFIG.update(vars(parse_arguments()))

    # Honour an explicit request to run on the CPU.
    if CONFIG.cpu:
        CONFIG.device = torch.device('cpu')

    # Outputs of this run live under record/<experiment_name>; the directory
    # must exist before logging opens its file inside it.
    CONFIG.save_dir = os.path.join('record', CONFIG.experiment_name)
    os.makedirs(CONFIG.save_dir, exist_ok=True)

    # Append message-only records at INFO level and above to log.txt.
    log_file = os.path.join(CONFIG.save_dir, 'log.txt')
    logging.basicConfig(
        filename=log_file,
        filemode='a',
        format='%(message)s',
        level=logging.INFO,
    )

    # Start a Weights & Biases run only when it was requested.
    if CONFIG.use_wandb:
        wandb_options = {
            'project': 'foresight-pruning',
            'name': CONFIG.experiment_name,
            'config': CONFIG,
        }
        wandb.init(**wandb_options)

    main()