-
Notifications
You must be signed in to change notification settings - Fork 2
/
imagenet.py
127 lines (95 loc) · 4.12 KB
/
imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# This code is taken from the official repository of "Neural Spline Flows" (https://github.com/bayesiains/nsf).
import os
import zipfile
import torch
from torch.utils.data import Dataset
import numpy as np
from torchvision.datasets.folder import (default_loader,
has_file_allowed_extension,
IMG_EXTENSIONS)
class UnlabelledImageFolder(Dataset):
    """Dataset over the image entries found directly inside one directory.

    Only the top level of ``root`` is listed (no recursion). An entry is kept
    when ``has_file_allowed_extension`` accepts its name for
    ``IMG_EXTENSIONS``; entries are ordered by sorted name.
    """

    def __init__(self, root, transform=None):
        """Index the images in ``root``.

        Args:
            root: Directory that is listed for image files.
            transform: Optional callable applied to every loaded image.
        """
        self.root = root
        self.transform = transform
        self.paths = self.find_images(os.path.join(root))

    def __len__(self):
        """Return how many image paths were indexed."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and pair it with a placeholder label.

        Returns:
            A tuple ``(image, label)``. ``image`` is whatever
            ``default_loader`` returns, passed through ``transform`` when one
            was given; ``label`` is always ``torch.tensor([0.])``.
        """
        sample = default_loader(self.paths[index])
        transform = self.transform
        if transform is not None:
            sample = transform(sample)
        # The constant label only exists so that this dataset yields
        # (input, target) pairs like standard labelled image datasets.
        return sample, torch.tensor([0.])

    @staticmethod
    def find_images(dir):
        """Return the sorted paths of the image entries directly in ``dir``.

        Args:
            dir: Directory to list.

        Returns:
            A list of ``os.path.join(dir, name)`` for every listed name whose
            extension is accepted for ``IMG_EXTENSIONS``, in sorted-name order.
        """
        return [
            os.path.join(dir, entry)
            for entry in sorted(os.listdir(dir))
            if has_file_allowed_extension(entry, IMG_EXTENSIONS)
        ]
class ImageNet32(UnlabelledImageFolder):
    """Unlabelled image dataset read from ``<root>/train`` or ``<root>/val``.

    When ``download=True`` and ``<root>/train`` does not exist yet, a zip
    archive is fetched from Google Drive, extracted into ``root`` and its two
    sub-directories are renamed to ``train`` and ``val``.

    NOTE(review): ``download_file_from_google_drive`` is not imported in the
    visible part of this file -- confirm it is brought into scope elsewhere.
    """

    # Google Drive ID of the zip archive.
    GOOGLE_DRIVE_FILE_ID = '1TXsg8TP5SfsSL6Gk39McCkZu9rhSQnNX'
    # Top-level directory produced by extracting the archive.
    UNZIPPED_DIR_NAME = 'imagenet32'
    # Sub-directories of the extracted archive holding the two splits.
    UNZIPPED_TRAIN_SUBDIR = 'train_32x32'
    UNZIPPED_VAL_SUBDIR = 'valid_32x32'

    def __init__(self, root, train=True, download=False, transform=None):
        """Optionally download the data, then index one split.

        Args:
            root: Directory containing (or about to receive) ``train``/``val``.
            train: Use ``<root>/train`` when true, ``<root>/val`` otherwise.
            download: Run ``_download(root)`` before indexing.
            transform: Optional callable applied to every loaded image.
        """
        if download:
            self._download(root)
        if train:
            split_dir = 'train'
        else:
            split_dir = 'val'
        super().__init__(os.path.join(root, split_dir), transform=transform)

    def _download(self, root):
        """Fetch and unpack the archive unless ``<root>/train`` already exists.

        Args:
            root: Directory the archive is downloaded to and extracted in.
        """
        # An existing train directory is taken as proof of an earlier download.
        if os.path.isdir(os.path.join(root, 'train')):
            return

        os.makedirs(root, exist_ok=True)
        archive_path = os.path.join(root, self.UNZIPPED_DIR_NAME + '.zip')
        archive_label = os.path.basename(archive_path)

        print('Downloading {}...'.format(archive_label))
        download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID, archive_path)

        print('Extracting {}...'.format(archive_label))
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(root)
        os.remove(archive_path)

        # Move the two split directories up into root under their short names,
        # then drop the extraction directory (os.rmdir requires it to be empty).
        unzipped_dir = os.path.join(root, self.UNZIPPED_DIR_NAME)
        renames = (
            (self.UNZIPPED_TRAIN_SUBDIR, 'train'),
            (self.UNZIPPED_VAL_SUBDIR, 'val'),
        )
        for extracted_name, final_name in renames:
            os.rename(os.path.join(unzipped_dir, extracted_name),
                      os.path.join(root, final_name))
        os.rmdir(unzipped_dir)
class ImageNet64(ImageNet32):
    """Variant of ``ImageNet32`` that points at the ``imagenet64`` archive.

    Only the class-level constants are overridden; construction and download
    logic are inherited unchanged.
    """

    # Top-level directory produced by extracting the archive, and the two
    # sub-directories inside it that hold the splits.
    UNZIPPED_DIR_NAME = 'imagenet64'
    UNZIPPED_TRAIN_SUBDIR = 'train_64x64'
    UNZIPPED_VAL_SUBDIR = 'valid_64x64'

    # Google Drive ID of the zip archive.
    GOOGLE_DRIVE_FILE_ID = '1NqpYnfluJz9A2INgsn16238FUfZh9QwR'
class ImageNet64Fast(Dataset):
    """Dataset that serves samples from one ``.npy`` array loaded up front.

    The array for the selected split is read with ``np.load`` in ``__init__``
    and wrapped as a tensor; samples are taken along its first axis.

    NOTE(review): ``download_file_from_google_drive`` is not imported in the
    visible part of this file -- confirm it is brought into scope elsewhere.
    """

    # Google Drive IDs of the per-split .npy files.
    GOOGLE_DRIVE_FILE_ID = {
        'train': '15AMmVSX-LDbP7LqC3R9Ns0RPbDI9301D',
        'valid': '1Me8EhsSwWbQjQ91vRG1emkIOCgDKK4yC'
    }
    # File names of the per-split .npy files inside ``root``.
    NPY_NAME = {
        'train': 'train_64x64.npy',
        'valid': 'valid_64x64.npy'
    }

    def __init__(self, root, train=True, download=False, transform=None):
        """Optionally download the arrays, then load one split.

        Args:
            root: Directory containing (or about to receive) the ``.npy`` files.
            train: Load the ``'train'`` file when true, ``'valid'`` otherwise.
            download: Run ``_download()`` before loading.
            transform: Optional callable applied to every sample.
        """
        self.root = root
        self.transform = transform
        if download:
            self._download()

        split = 'train' if train else 'valid'
        array = np.load(os.path.join(root, self.NPY_NAME[split]))
        # from_numpy wraps the existing buffer, so this shouldn't make a copy.
        self.data = torch.from_numpy(array)

    def __len__(self):
        """Return the size of the first axis of the loaded array."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return sample ``index`` paired with a placeholder label.

        Returns:
            A tuple ``(img, label)``. ``img`` is ``self.data[index, ...]``,
            passed through ``transform`` when one was given; ``label`` is
            always ``torch.tensor([0.])``.
        """
        sample = self.data[index, ...]
        if self.transform is not None:
            sample = self.transform(sample)
        # The constant label only exists so that this dataset yields
        # (input, target) pairs like standard labelled image datasets.
        return sample, torch.tensor([0.])

    def _download(self):
        """Download whichever of the two ``.npy`` files is missing from root."""
        os.makedirs(self.root, exist_ok=True)
        for split in ('train', 'valid'):
            file_name = self.NPY_NAME[split]
            target = os.path.join(self.root, file_name)
            if os.path.isfile(target):
                continue  # Already present; nothing to fetch for this split.
            print('Downloading {}...'.format(file_name))
            download_file_from_google_drive(self.GOOGLE_DRIVE_FILE_ID[split], target)