From 4bd851deb91d8976a40e32bd5386aaf6d6905419 Mon Sep 17 00:00:00 2001 From: ndrplz Date: Sun, 15 Oct 2017 13:11:56 +0200 Subject: [PATCH] Implement simple data augmentation for training batches --- .../traffic_light_dataset.py | 36 +++++++++++++++++-- capstone_traffic_light_classifier/train.py | 2 +- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/capstone_traffic_light_classifier/traffic_light_dataset.py b/capstone_traffic_light_classifier/traffic_light_dataset.py index 33a2dea..0ab7a83 100644 --- a/capstone_traffic_light_classifier/traffic_light_dataset.py +++ b/capstone_traffic_light_classifier/traffic_light_dataset.py @@ -1,4 +1,5 @@ import cv2 +import random import numpy as np from glob import glob from os.path import join @@ -45,8 +46,10 @@ def dump_to_npy(self, dump_file_path): raise IOError('Please initialize dataset first.') np.save(dump_file_path, self.dataset_npy) - def load_batch(self, batch_size): - + def load_batch(self, batch_size, augmentation=False): + """ + Load a random batch of data from the dataset + """ if not self.initialized: raise IOError('Please initialize dataset first.') @@ -64,6 +67,8 @@ def load_batch(self, batch_size): loaded += 1 X_batch = self.preprocess(X_batch) + if augmentation: + X_batch = self.perform_augmentation(X_batch) return X_batch, Y_batch @@ -76,6 +81,22 @@ def preprocess(x): x /= x.max() return x + @staticmethod + def perform_augmentation(batch): + """ + Perform simple data augmentation on training batch + """ + def coin_flip_is_head(): + return random.choice([True, False]) + + for b in range(batch.shape[0]): + if coin_flip_is_head(): + batch[b] = np.fliplr(batch[b]) + if coin_flip_is_head(): + batch[b] = np.flipud(batch[b]) + + return batch + @staticmethod def infer_label_from_frame_path(path): label = -1 @@ -100,3 +121,14 @@ def print_statistics(self): for (color, num_label) in color2label.items(): statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label) print(statistics) + +if __name__ == '__main__': + # Init traffic light dataset + dataset = TrafficLightDataset() + # dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset' + # dataset.init_from_files(dataset_root, resize=(input_h, input_w)) + # dataset.dump_to_npy('traffic_light_dataset.npy') + dataset.init_from_npy('traffic_light_dataset.npy') + + # Load a batch of training data + x_batch, y_batch = dataset.load_batch(batch_size=16, augmentation=True) \ No newline at end of file diff --git a/capstone_traffic_light_classifier/train.py b/capstone_traffic_light_classifier/train.py index 5ce2f21..9830ddd 100644 --- a/capstone_traffic_light_classifier/train.py +++ b/capstone_traffic_light_classifier/train.py @@ -45,7 +45,7 @@ for _ in range(batches_each_epoch): # Load a batch of training data - x_batch, y_batch = dataset.load_batch(batch_size) + x_batch, y_batch = dataset.load_batch(batch_size, augmentation=True) # Actually run one training step here _, loss_this_batch = sess.run(fetches=[classifier.train_step, classifier.loss],