Skip to content

Commit

Permalink
Implement simple data augmentation for training batches
Browse files: browse the repository at this point in the history
  • Loading branch information
ndrplz committed Oct 15, 2017
1 parent f1b8646 commit 4bd851d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
36 changes: 34 additions & 2 deletions capstone_traffic_light_classifier/traffic_light_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import cv2
import random
import numpy as np
from glob import glob
from os.path import join
Expand Down Expand Up @@ -45,8 +46,10 @@ def dump_to_npy(self, dump_file_path):
raise IOError('Please initialize dataset first.')
np.save(dump_file_path, self.dataset_npy)

def load_batch(self, batch_size):

def load_batch(self, batch_size, augmentation=False):
"""
Load a random batch of data from the dataset
"""
if not self.initialized:
raise IOError('Please initialize dataset first.')

Expand All @@ -64,6 +67,8 @@ def load_batch(self, batch_size):
loaded += 1

X_batch = self.preprocess(X_batch)
if augmentation:
X_batch = self.perform_augmentation(X_batch)

return X_batch, Y_batch

Expand All @@ -76,6 +81,22 @@ def preprocess(x):
x /= x.max()
return x

@staticmethod
def perform_augmentation(batch):
"""
Perform simple data augmentation on training batch
"""
def coin_flip_is_head():
return random.choice([True, False])

for b in range(batch.shape[0]):
if coin_flip_is_head():
batch[b] = np.fliplr(batch[b])
if coin_flip_is_head():
batch[b] = np.flipud(batch[b])

return batch

@staticmethod
def infer_label_from_frame_path(path):
label = -1
Expand All @@ -100,3 +121,14 @@ def print_statistics(self):
for (color, num_label) in color2label.items():
statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label)
print(statistics)

if __name__ == '__main__':
    # Manual smoke test: build the dataset from its cached .npy dump and
    # pull one augmented batch from it.
    dataset = TrafficLightDataset()

    # The three commented lines below are the alternative path that builds
    # the dataset from raw frames and writes the .npy dump loaded here.
    # dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset'
    # dataset.init_from_files(dataset_root, resize=(input_h, input_w))
    # dataset.dump_to_npy('traffic_light_dataset.npy')
    dataset.init_from_npy('traffic_light_dataset.npy')

    # Draw 16 random samples with augmentation switched on
    x_batch, y_batch = dataset.load_batch(batch_size=16, augmentation=True)
2 changes: 1 addition & 1 deletion capstone_traffic_light_classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
for _ in range(batches_each_epoch):

# Load a batch of training data
x_batch, y_batch = dataset.load_batch(batch_size)
x_batch, y_batch = dataset.load_batch(batch_size, augmentation=True)

# Actually run one training step here
_, loss_this_batch = sess.run(fetches=[classifier.train_step, classifier.loss],
Expand Down

0 comments on commit 4bd851d

Please sign in to comment.