-
Notifications
You must be signed in to change notification settings - Fork 68
/
main_download_pretrained_models.py
142 lines (110 loc) · 7.52 KB
/
main_download_pretrained_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import os
import requests
import re
"""
How to use:
download all the models:
python main_download_pretrained_models.py --models "all" --model_dir "model_zoo"
download DnCNN models:
python main_download_pretrained_models.py --models "DnCNN" --model_dir "model_zoo"
download SRMD models:
python main_download_pretrained_models.py --models "SRMD" --model_dir "model_zoo"
download BSRGAN models:
python main_download_pretrained_models.py --models "BSRGAN" --model_dir "model_zoo"
download FFDNet models:
python main_download_pretrained_models.py --models "FFDNet" --model_dir "model_zoo"
download DPSR models:
python main_download_pretrained_models.py --models "DPSR" --model_dir "model_zoo"
download SwinIR models:
python main_download_pretrained_models.py --models "SwinIR" --model_dir "model_zoo"
download VRT models:
python main_download_pretrained_models.py --models "VRT" --model_dir "model_zoo"
download other models:
python main_download_pretrained_models.py --models "others" --model_dir "model_zoo"
------------------------------------------------------------------
download 'dncnn_15.pth' and 'dncnn_50.pth'
python main_download_pretrained_models.py --models "dncnn_15.pth dncnn_50.pth" --model_dir "model_zoo"
------------------------------------------------------------------
download DnCNN models and 'BSRGAN.pth'
python main_download_pretrained_models.py --models "DnCNN BSRGAN.pth" --model_dir "model_zoo"
"""
def download_pretrained_model(model_dir='model_zoo', model_name='dncnn3.pth'):
if os.path.exists(os.path.join(model_dir, model_name)):
print(f'already exists, skip downloading [{model_name}]')
else:
os.makedirs(model_dir, exist_ok=True)
if 'SwinIR' in model_name:
url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(model_name)
elif 'VRT' in model_name:
url = 'https://github.com/JingyunLiang/VRT/releases/download/v0.0/{}'.format(model_name)
else:
url = 'https://github.com/cszn/KAIR/releases/download/v1.0/{}'.format(model_name)
r = requests.get(url, allow_redirects=True)
print(f'downloading [{model_dir}/{model_name}] ...')
open(os.path.join(model_dir, model_name), 'wb').write(r.content)
print('done!')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--models',
type=lambda s: re.split(' |, ', s),
default = "dncnn3.pth",
help='comma or space delimited list of characters, e.g., "DnCNN", "DnCNN BSRGAN.pth", "dncnn_15.pth dncnn_50.pth"')
parser.add_argument('--model_dir', type=str, default='model_zoo', help='path of model_zoo')
args = parser.parse_args()
print(f'trying to download {args.models}')
method_model_zoo = {'DnCNN': ['dncnn_15.pth', 'dncnn_25.pth', 'dncnn_50.pth', 'dncnn3.pth', 'dncnn_color_blind.pth', 'dncnn_gray_blind.pth'],
'SRMD': ['srmdnf_x2.pth', 'srmdnf_x3.pth', 'srmdnf_x4.pth', 'srmd_x2.pth', 'srmd_x3.pth', 'srmd_x4.pth'],
'DPSR': ['dpsr_x2.pth', 'dpsr_x3.pth', 'dpsr_x4.pth', 'dpsr_x4_gan.pth'],
'FFDNet': ['ffdnet_color.pth', 'ffdnet_gray.pth', 'ffdnet_color_clip.pth', 'ffdnet_gray_clip.pth'],
'USRNet': ['usrgan.pth', 'usrgan_tiny.pth', 'usrnet.pth', 'usrnet_tiny.pth'],
'DPIR': ['drunet_gray.pth', 'drunet_color.pth', 'drunet_deblocking_color.pth', 'drunet_deblocking_grayscale.pth'],
'BSRGAN': ['BSRGAN.pth', 'BSRNet.pth', 'BSRGANx2.pth'],
'IRCNN': ['ircnn_color.pth', 'ircnn_gray.pth'],
'SCUNet': ['scunet_gray_15.pth', 'scunet_gray_25.pth', 'scunet_gray_50.pth', 'scunet_color_15.pth', 'scunet_color_25.pth', 'scunet_color_50.pth', 'scunet_color_real_psnr.pth', 'scunet_color_real_gan.pth'],
'SwinIR': ['001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x3.pth',
'001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth', '001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth',
'001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x3.pth',
'001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth', '001_classicalSR_DIV2K_s48w8_SwinIR-M_x8.pth',
'002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth', '002_lightweightSR_DIV2K_s64w8_SwinIR-S_x3.pth',
'002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth', '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
'003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_PSNR.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth',
'004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth', '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth',
'005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth', '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth',
'005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth',
'006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth', '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth',
'006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth'],
'VRT': ['001_VRT_videosr_bi_REDS_6frames.pth', '002_VRT_videosr_bi_REDS_16frames.pth',
'003_VRT_videosr_bi_Vimeo_7frames.pth', '004_VRT_videosr_bd_Vimeo_7frames.pth',
'005_VRT_videodeblurring_DVD.pth', '006_VRT_videodeblurring_GoPro.pth',
'007_VRT_videodeblurring_REDS.pth', '008_VRT_videodenoising_DAVIS.pth'],
'others': ['msrresnet_x4_psnr.pth', 'msrresnet_x4_gan.pth', 'imdn_x4.pth', 'RRDB.pth', 'ESRGAN.pth',
'FSSR_DPED.pth', 'FSSR_JPEG.pth', 'RealSR_DPED.pth', 'RealSR_JPEG.pth']
}
method_zoo = list(method_model_zoo.keys())
model_zoo = []
for b in list(method_model_zoo.values()):
model_zoo += b
if 'all' in args.models:
for method in method_zoo:
for model_name in method_model_zoo[method]:
download_pretrained_model(args.model_dir, model_name)
else:
for method_model in args.models:
if method_model in method_zoo: # method, need for loop
for model_name in method_model_zoo[method_model]:
if 'SwinIR' in model_name:
download_pretrained_model(os.path.join(args.model_dir, 'swinir'), model_name)
elif 'VRT' in model_name:
download_pretrained_model(os.path.join(args.model_dir, 'vrt'), model_name)
else:
download_pretrained_model(args.model_dir, model_name)
elif method_model in model_zoo: # model, do not need for loop
if 'SwinIR' in method_model:
download_pretrained_model(os.path.join(args.model_dir, 'swinir'), method_model)
elif 'VRT' in method_model:
download_pretrained_model(os.path.join(args.model_dir, 'vrt'), method_model)
else:
download_pretrained_model(args.model_dir, method_model)
else:
print(f'Do not find {method_model} from the pre-trained model zoo!')