diff --git a/capstone_traffic_light_classifier/traffic_light_dataset.py b/capstone_traffic_light_classifier/traffic_light_dataset.py index 4f534af..5d37ec9 100644 --- a/capstone_traffic_light_classifier/traffic_light_dataset.py +++ b/capstone_traffic_light_classifier/traffic_light_dataset.py @@ -72,6 +72,21 @@ def load_batch(self, batch_size, augmentation=False): return X_batch, Y_batch + def print_statistics(self): + """ + Print simple statistics on the number of samples in the dataset. + :return: + """ + if not self.initialized: + raise IOError('Please initialize dataset first.') + + color2label = {'none': 0, 'red': 1, 'yellow': 2, 'green': 3} + + statistics = {} + for (color, num_label) in color2label.items(): + statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label) + print(statistics) + @staticmethod def preprocess(x): """ @@ -113,26 +128,3 @@ def infer_label_from_frame_path(path): elif 'green' in path: label = 3 return label - - def print_statistics(self): - - if not self.initialized: - raise IOError('Please initialize dataset first.') - - color2label = {'none': 0, 'red': 1, 'yellow': 2, 'green': 3} - - statistics = {} - for (color, num_label) in color2label.items(): - statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label) - print(statistics) - -if __name__ == '__main__': - # Init traffic light dataset - dataset = TrafficLightDataset() - # dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset' - # dataset.init_from_files(dataset_root, resize=(input_h, input_w)) - # dataset.dump_to_npy('traffic_light_dataset.npy') - dataset.init_from_npy('traffic_light_dataset.npy') - - # Load a batch of training data - x_batch, y_batch = dataset.load_batch(batch_size=16, augmentation=True) \ No newline at end of file diff --git a/capstone_traffic_light_classifier/train.py b/capstone_traffic_light_classifier/train.py index 6e35f83..23cefdf 100644 --- a/capstone_traffic_light_classifier/train.py +++ b/capstone_traffic_light_classifier/train.py @@ -8,21 +8,19 @@ if __name__ == '__main__': # Parameters - input_h, input_w = 64, 64 # Shape to which input is resized + input_h, input_w = 128, 128 # Shape to which input is resized # Init traffic light dataset dataset = TrafficLightDataset() - dataset_root = 'C:/Users/minotauro/Google Drive/SHARE/traffic_light_dataset' - dataset.init_from_files(dataset_root, resize=(input_h, input_w)) - dataset.dump_to_npy('traffic_light_dataset.npy') - # dataset.init_from_npy('traffic_light_dataset.npy') + dataset_file = 'traffic_light_dataset_npy/traffic_light_dataset_mixed_resize_{}.npy'.format(input_h) + dataset.init_from_npy(dataset_file) # Define model - classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4) + classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4, verbose=True) # Checkpoint stuff - saver = tf.train.Saver() # saver to save the model after each epoch - checkpoint_dir = './checkpoint_2' # checkpoint directory + saver = tf.train.Saver() # saver to save the model after each epoch + checkpoint_dir = './checkpoint_mixed_{}'.format(input_h) # checkpoint directory if not exists(checkpoint_dir): makedirs(checkpoint_dir)