-
Notifications
You must be signed in to change notification settings - Fork 15
/
augment.py
44 lines (34 loc) · 1.09 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import hub
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Image augmentation pipeline; the transforms are applied in the order listed.
augment = A.Compose(
    [
        # Rescale so the shorter side is 160 px, leaving room for the crop below.
        A.SmallestMaxSize(max_size=160),
        # Small random geometric jitter, applied to half of the samples.
        A.ShiftScaleRotate(
            p=0.5,
            rotate_limit=15,
            scale_limit=0.05,
            shift_limit=0.05,
        ),
        # Fixed 128x128 output patch taken at a random location.
        A.RandomCrop(width=128, height=128),
        # Random per-channel colour offsets, applied to half of the samples.
        A.RGBShift(
            p=0.5,
            r_shift_limit=15,
            g_shift_limit=15,
            b_shift_limit=15,
        ),
        A.RandomBrightnessContrast(p=0.5),
        # Per-channel mean/std normalisation
        # (these look like the usual ImageNet statistics -- TODO confirm intent).
        A.Normalize(
            std=(0.229, 0.224, 0.225),
            mean=(0.485, 0.456, 0.406),
        ),
        # Final step: hand back a PyTorch tensor instead of a NumPy array.
        ToTensorV2(),
    ]
)
def transform(sample):
    """Apply the module-level ``augment`` pipeline to one dataset sample.

    Args:
        sample: An ordered dictionary of dataset elements. Only the
            ``"images"`` and ``"labels"`` entries are read.

    Returns:
        tuple: ``(image, label)`` where ``image`` is the output of
        ``augment`` and ``label`` is passed through unchanged.

    Raises:
        KeyError: If ``sample`` has no ``"images"`` or ``"labels"`` entry.
    """
    image, label = sample["images"], sample["labels"]
    # The pipeline is called with keyword arguments and returns a dict;
    # the augmented image is stored under the "image" key.
    image = augment(image=image)["image"]
    return image, label
def loop():
    """Load the CIFAR-100 training set and print the shapes of one batch.

    Loads ``hub://activeloop/cifar100-train``, wraps it in a PyTorch
    dataloader that applies ``transform`` to each sample, then prints the
    shapes of the first batch of images and labels and stops.
    """
    # Load the dataset
    ds = hub.load("hub://activeloop/cifar100-train")

    # Define the dataloader with the transform
    dataloader = ds.pytorch(
        transform=transform,
        num_workers=2,
        batch_size=8,
    )

    # Iterate; only the first batch is inspected, so leave the loop right away.
    for images, labels in dataloader:
        print(images.shape, labels.shape)
        break
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    loop()