This repository has been archived by the owner on Jul 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 30
/
datasets.py
103 lines (71 loc) · 2.54 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch.utils.data as data
from PIL import Image
import numpy as np
from torchvision.datasets import MNIST, CIFAR10
class MNIST_truncated(data.Dataset):
    """MNIST dataset optionally restricted to a subset of sample indices.

    The full torchvision ``MNIST`` split is loaded once. Its ``data`` and
    ``targets`` are kept, indexed by ``dataidxs`` when that is provided.

    Args:
        root: Root directory forwarded to ``MNIST``.
        dataidxs: Optional indices of the samples to keep; ``None`` keeps
            the whole split.
        train: Forwarded to ``MNIST`` to choose the split.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        download: Forwarded to ``MNIST``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load MNIST and return ``(data, target)``, subset by ``dataidxs``."""
        source = MNIST(
            self.root,
            train=self.train,
            transform=self.transform,
            target_transform=self.target_transform,
            download=self.download,
        )
        images = source.data
        labels = source.targets
        if self.dataidxs is None:
            return images, labels
        return images[self.dataidxs], labels[self.dataidxs]

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The image is a PIL Image in mode 'L' (after ``transform`` if set).
        The target is the stored label (after ``target_transform`` if set).
        """
        raw_image, label = self.data[index], self.target[index]
        # Hand transforms a PIL Image, matching what other torchvision
        # datasets yield. ``.numpy()`` presumes ``self.data`` holds tensors.
        image = Image.fromarray(raw_image.numpy(), mode='L')
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.data)
class CIFAR10_truncated(data.Dataset):
    """CIFAR-10 dataset optionally restricted to a subset of sample indices.

    The full torchvision ``CIFAR10`` split is loaded once. Its ``data`` and
    ``targets`` are converted to NumPy arrays and, when ``dataidxs`` is
    given, only the rows at those indices are kept.

    NOTE(review): unlike ``MNIST_truncated``, ``__getitem__`` hands the raw
    NumPy image (not a PIL Image) to ``transform``. Callers' transform
    pipelines may depend on that, so it is left unchanged.

    Args:
        root: Root directory forwarded to ``CIFAR10``.
        dataidxs: Optional indices (list or integer array) of the samples to
            keep; ``None`` keeps the whole split.
        train: Forwarded to ``CIFAR10`` to choose the split.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each label in
            ``__getitem__``.
        download: Forwarded to ``CIFAR10``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load CIFAR-10 and return ``(data, target)`` as NumPy arrays.

        Returns:
            Tuple ``(data, target)``, restricted to ``self.dataidxs`` when
            that is not None.
        """
        cifar_dataobj = CIFAR10(
            self.root,
            train=self.train,
            transform=self.transform,
            target_transform=self.target_transform,
            download=self.download,
        )
        # np.asarray reuses an existing ndarray instead of duplicating the
        # whole image array (np.array copies by default). Subsetting with
        # fancy indexing below produces its own copy anyway.
        images = np.asarray(cifar_dataobj.data)
        labels = np.asarray(cifar_dataobj.targets)
        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]
        return images, labels

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: ``(image, target)``, each passed through ``transform`` /
            ``target_transform`` when those are set.
        """
        img, target = self.data[index], self.target[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.data)