forked from JiehongLin/VI-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_preprocess.py
53 lines (41 loc) · 1.96 KB
/
data_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import json, os, glob
import _pickle as cPickle
from collections import defaultdict
def load_stats_train(data_dir, dataset, list_name):
    """Group the annotated instances of a training split by class id.

    Reads the image list ``data_dir/dataset/list_name`` (one relative image
    prefix per line), opens ``<prefix>_label.pkl`` for every listed image and
    files each entry of its ``'class_ids'`` sequence under that class id.

    Args:
        data_dir: Root directory of the data tree.
        dataset: Sub-directory of ``data_dir`` that holds the list file and
            the images it refers to.
        list_name: File name of the image list inside ``dataset``.

    Returns:
        A ``defaultdict(list)`` mapping each class id (as a built-in ``int``)
        to a list of ``(image_prefix, instance_index)`` tuples, where
        ``instance_index`` is the position of the instance in ``'class_ids'``.

    Raises:
        OSError: If the list file or one of the label files cannot be opened.

    NOTE: label files are unpickled; only run this on trusted data.
    """
    img_path = os.path.join(dataset, list_name)
    # Image prefixes in the list are relative to the first component of the
    # list path, i.e. the dataset directory.
    dataset_root = img_path.split('/')[0]

    # Use a context manager so the list file is closed deterministically, and
    # skip blank lines so a trailing newline cannot yield a bogus image path.
    with open(os.path.join(data_dir, img_path)) as list_file:
        img_list = [
            os.path.join(data_dir, dataset_root, line.rstrip('\n'))
            for line in list_file
            if line.strip()
        ]

    dict_cat = defaultdict(list)
    for img in img_list:
        with open(img + '_label.pkl', 'rb') as f:
            gts = cPickle.load(f)
        for instance, cls in enumerate(gts['class_ids']):
            # Cast to a built-in int (as load_stats_test does) so the keys are
            # always valid JSON object keys, whatever integer type was pickled.
            dict_cat[int(cls)].append((img, instance))
    return dict_cat
def load_stats_test(data_dir, dataset):
    """Group the predicted instances of a test split by class id.

    Loads every ``results_*.pkl`` file found in
    ``data_dir/detection/dataset`` and files each entry of its
    ``'pred_class_ids'`` sequence under that class id.

    Args:
        data_dir: Root directory of the data tree.
        dataset: Name of the sub-directory of ``data_dir/detection`` that
            holds the result pickles.

    Returns:
        A ``defaultdict(list)`` mapping each class id (as a built-in ``int``)
        to a list of ``(result_pkl_path, instance_index)`` tuples, where
        ``instance_index`` is the position of the prediction in
        ``'pred_class_ids'``. Files are visited in sorted path order.

    NOTE: result files are unpickled; only run this on trusted data.
    """
    pattern = os.path.join(data_dir, 'detection', dataset, 'results_*.pkl')
    # glob returns matches in arbitrary, filesystem-dependent order; sort them
    # so the per-category lists are reproducible across machines and runs.
    result_pkl_list = sorted(glob.glob(pattern))

    dict_cat = defaultdict(list)
    for path in result_pkl_list:
        with open(path, 'rb') as f:
            pred_data = cPickle.load(f)
        for instance, cls in enumerate(pred_data['pred_class_ids']):
            # int() turns non-builtin integer ids into JSON-compatible keys.
            dict_cat[int(cls)].append((path, instance))
    return dict_cat
# Root of the NOCS data tree, relative to the current working directory.
data_dir = "../../data/NOCS/"


def _write_json(stats, *path_parts):
    """Dump ``stats`` as JSON to ``data_dir`` joined with ``path_parts``."""
    target = os.path.join(data_dir, *path_parts)
    with open(target, 'w') as out_file:
        json.dump(stats, out_file)


# Training splits: per-category instance index built from the label pickles.
print("camera train")
camera_train_stats = load_stats_train(data_dir, 'camera', 'train_list.txt')
_write_json(camera_train_stats, 'camera', 'train_category_dict.json')

print('real train')
real_train_stats = load_stats_train(data_dir, 'real', 'train_list.txt')
_write_json(real_train_stats, 'real', 'train_category_dict.json')

# Test splits: per-category instance index built from the detection results.
print('camera test')
camera_test_stats = load_stats_test(data_dir, 'CAMERA25')
_write_json(camera_test_stats, 'detection', 'camera_test_category_dict.json')

print('real test')
real_test_stats = load_stats_test(data_dir, 'REAL275')
_write_json(real_test_stats, 'detection', 'real_test_category_dict.json')