-
Notifications
You must be signed in to change notification settings - Fork 15
/
upload.py
115 lines (95 loc) · 3.01 KB
/
upload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import hub
import numpy as np
from PIL import Image
import argparse
import tqdm
import time
import traceback
import sys
import logging
import torchvision.datasets as datasets
# Defaults for the CLI flags below.
NUM_WORKERS = 1
DS_OUT_PATH = "./data/places365"  # optionally s3://, gcs:// or hub:// path
DOWNLOAD = False
# Which Places365 splits to upload; uncomment to process more.
splits = [
    "train-standard",
    # "val",
    # "train-challenge"
]


def _str_to_bool(value: str) -> bool:
    """Parse a CLI boolean string (e.g. "True"/"false"/"1"/"0").

    NOTE: the previous `type=bool` was a bug — argparse would call
    `bool("False")`, which is True for ANY non-empty string, so
    `--download False` still triggered a download.
    """
    return value.strip().lower() in ("1", "true", "yes", "y")


parser = argparse.ArgumentParser(description="Hub Places365 Uploading")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "--num_workers",
    type=int,
    default=NUM_WORKERS,
    metavar="O",
    help="number of workers to allocate",
)
parser.add_argument(
    "--ds_out",
    type=str,
    default=DS_OUT_PATH,
    metavar="O",
    help="dataset path to be transformed into",
)
parser.add_argument(
    "--download",
    type=_str_to_bool,
    default=DOWNLOAD,
    metavar="O",
    help="Download from the source http://places2.csail.mit.edu/download.html",
)
args = parser.parse_args()
def define_dataset(path: str, class_names: list = None):
    """Create (or overwrite) an empty Hub dataset with image/label tensors.

    Args:
        path: Target dataset path (local, s3://, gcs:// or hub://).
        class_names: Optional list of class-label names. Defaults to an
            empty list. (A `None` sentinel replaces the original mutable
            default argument `[]`, which is shared across calls in Python.)

    Returns:
        The newly created `hub` dataset with "images" and "labels" tensors.
    """
    class_names = [] if class_names is None else class_names
    ds = hub.empty(path, overwrite=True)
    ds.create_tensor("images", htype="image", sample_compression="jpg")
    ds.create_tensor("labels", htype="class_label", class_names=class_names)
    return ds
@hub.compute
def upload_parallel(pair_in, sample_out):
    """Hub transform: load one (filepath, label) pair and append it to the output.

    Failures are logged and skipped so one bad file does not abort the upload.
    """
    filepath, target = pair_in[0], pair_in[1]
    try:
        img = Image.open(filepath)
        # Normalize non-RGB images (grayscale, palette, RGBA, ...) to RGB so
        # every sample has 3 channels. The original check
        # `len(img.size) == 2` was always true — PIL's `.size` is always a
        # (width, height) 2-tuple — so it never actually detected grayscale.
        if img.mode != "RGB":
            img = img.convert("RGB")
        sample_out.images.append(np.asarray(img))
        sample_out.labels.append(target)
    except Exception:
        # logging.exception keeps the traceback, unlike the original
        # logging.error with an unused `e`.
        logging.exception(f"failed uploading {filepath} with target {target}")
def upload_iteration(filenames_target: list, ds: hub.Dataset):
    """Sequentially upload (filepath, label) pairs into `ds` with a progress bar.

    Args:
        filenames_target: Iterable of (filepath, target_label) pairs.
        ds: Destination hub dataset with "images" and "labels" tensors.

    Failures on individual files are logged and skipped.
    """
    with ds:
        for filepath, target in tqdm.tqdm(filenames_target):
            try:
                img = Image.open(filepath)
                # Normalize non-RGB images to RGB. The original condition
                # `len(img.size) == 2` was always true (PIL `.size` is always
                # a 2-tuple), so it did not actually detect grayscale images.
                if img.mode != "RGB":
                    img = img.convert("RGB")
                ds.images.append(np.asarray(img))
                ds.labels.append(target)
            except Exception:
                # Keep the traceback in the log instead of discarding it.
                logging.exception(
                    f"failed uploading {filepath} with target {target}"
                )
if __name__ == "__main__":
    for split in splits:
        # Build the torchvision view of Places365 (optionally downloading it).
        source = datasets.Places365(
            args.data,
            split=split,
            download=args.download,
        )

        # Category names come back like "/a/abbey"; keep only the leaf part.
        raw_categories = source.load_categories()[0]
        category_names = ["/".join(c.split("/")[2:]) for c in raw_categories]

        # One output dataset per split, e.g. "./data/places365-train-standard".
        ds = define_dataset(f"{args.ds_out}-{split}", category_names)
        pairs = source.load_file_list()[0]

        print(f"uploading {split}...")
        t1 = time.time()
        if args.num_workers > 1:
            # Fan out across worker processes via the hub transform.
            upload_parallel().eval(
                pairs,
                ds,
                num_workers=args.num_workers,
                scheduler="processed",
            )
        else:
            # Single-process fallback with a tqdm progress bar.
            upload_iteration(pairs, ds)
        t2 = time.time()
        print(f"uploading {split} took {t2-t1}s")