Skip to content

Commit

Permalink
Refactor code to work with newer dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Oct 20, 2017
1 parent 3464dbf commit 89e8d28
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
38 changes: 15 additions & 23 deletions capstone_traffic_light_classifier/traffic_light_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def load_batch(self, batch_size, augmentation=False):

return X_batch, Y_batch

def print_statistics(self):
"""
Print simple statistics on the number of samples in the dataset.
:return:
"""
if not self.initialized:
raise IOError('Please initialize dataset first.')

color2label = {'none': 0, 'red': 1, 'yellow': 2, 'green': 3}

statistics = {}
for (color, num_label) in color2label.items():
statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label)
print(statistics)

@staticmethod
def preprocess(x):
"""
Expand Down Expand Up @@ -113,26 +128,3 @@ def infer_label_from_frame_path(path):
elif 'green' in path:
label = 3
return label

def print_statistics(self):

if not self.initialized:
raise IOError('Please initialize dataset first.')

color2label = {'none': 0, 'red': 1, 'yellow': 2, 'green': 3}

statistics = {}
for (color, num_label) in color2label.items():
statistics[color] = np.sum(self.dataset_npy[:, 1] == num_label)
print(statistics)

if __name__ == '__main__':
# Init traffic light dataset
dataset = TrafficLightDataset()
# dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset'
# dataset.init_from_files(dataset_root, resize=(input_h, input_w))
# dataset.dump_to_npy('traffic_light_dataset.npy')
dataset.init_from_npy('traffic_light_dataset.npy')

# Load a batch of training data
x_batch, y_batch = dataset.load_batch(batch_size=16, augmentation=True)
14 changes: 6 additions & 8 deletions capstone_traffic_light_classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
if __name__ == '__main__':

# Parameters
input_h, input_w = 64, 64 # Shape to which input is resized
input_h, input_w = 128, 128 # Shape to which input is resized

# Init traffic light dataset
dataset = TrafficLightDataset()
dataset_root = 'C:/Users/minotauro/Google Drive/SHARE/traffic_light_dataset'
dataset.init_from_files(dataset_root, resize=(input_h, input_w))
dataset.dump_to_npy('traffic_light_dataset.npy')
# dataset.init_from_npy('traffic_light_dataset.npy')
dataset_file = 'traffic_light_dataset_npy/traffic_light_dataset_mixed_resize_{}.npy'.format(input_h)
dataset.init_from_npy(dataset_file)

# Define model
classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4)
classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4, verbose=True)

# Checkpoint stuff
saver = tf.train.Saver() # saver to save the model after each epoch
checkpoint_dir = './checkpoint_2' # checkpoint directory
saver = tf.train.Saver() # saver to save the model after each epoch
checkpoint_dir = './checkpoint_mixed_{}'.format(input_h) # checkpoint directory
if not exists(checkpoint_dir):
makedirs(checkpoint_dir)

Expand Down

0 comments on commit 89e8d28

Please sign in to comment.