diff --git a/capstone_traffic_light_classifier/traffic_light_classifier.py b/capstone_traffic_light_classifier/traffic_light_classifier.py index f5c9834..696de7d 100644 --- a/capstone_traffic_light_classifier/traffic_light_classifier.py +++ b/capstone_traffic_light_classifier/traffic_light_classifier.py @@ -7,13 +7,15 @@ class TrafficLightClassifier: - def __init__(self, x, targets, keep_prob, n_classes, learning_rate): + def __init__(self, input_shape, learning_rate): - self.x = x - self.targets = targets - self.keep_prob = keep_prob + # Placeholders + input_h, input_w = input_shape + self.x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder + self.targets = tf.placeholder(dtype=tf.int32, shape=[None]) + self.keep_prob = tf.placeholder(dtype=tf.float32) # dropout keep probability - self.n_classes = n_classes # {void, red, yellow, green} + self.n_classes = 4 # {void, red, yellow, green} self.learning_rate = learning_rate # learning rate used in train step self._inference = None @@ -68,7 +70,7 @@ def loss(self): def train_step(self): if self._train_step is None: with tf.variable_scope('training'): - self._train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss) + self._train_step = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss) return self._train_step @property diff --git a/capstone_traffic_light_classifier/train.py b/capstone_traffic_light_classifier/train.py index 9830ddd..6e35f83 100644 --- a/capstone_traffic_light_classifier/train.py +++ b/capstone_traffic_light_classifier/train.py @@ -1,4 +1,6 @@ import tensorflow as tf +from os import makedirs +from os.path import exists, join from traffic_light_dataset import TrafficLightDataset from traffic_light_classifier import TrafficLightClassifier @@ -6,26 +8,23 @@ if __name__ == '__main__': # Parameters - n_classes = 4 # Namely `void`, `red`, `yellow`, `green` input_h, input_w = 64, 64 # Shape to which input is resized # Init traffic light dataset dataset = TrafficLightDataset() - # dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset' - # dataset.init_from_files(dataset_root, resize=(input_h, input_w)) - # dataset.dump_to_npy('traffic_light_dataset.npy') - dataset.init_from_npy('traffic_light_dataset.npy') - - # Placeholders - x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder - p = tf.placeholder(dtype=tf.float32) # dropout keep probability - targets = tf.placeholder(dtype=tf.int32, shape=[None]) + dataset_root = 'C:/Users/minotauro/Google Drive/SHARE/traffic_light_dataset' + dataset.init_from_files(dataset_root, resize=(input_h, input_w)) + dataset.dump_to_npy('traffic_light_dataset.npy') + # dataset.init_from_npy('traffic_light_dataset.npy') # Define model - classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4) + classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4) - # Add a saver to save the model after each epoch - saver = tf.train.Saver() + # Checkpoint stuff + saver = tf.train.Saver() # saver to save the model after each epoch + checkpoint_dir = './checkpoint_2' # checkpoint directory + if not exists(checkpoint_dir): + makedirs(checkpoint_dir) with tf.Session() as sess: @@ -49,7 +48,9 @@ # Actually run one training step here _, loss_this_batch = sess.run(fetches=[classifier.train_step, classifier.loss], - feed_dict={x: x_batch, targets: y_batch, p: 0.5}) + feed_dict={classifier.x: x_batch, + classifier.targets: y_batch, + classifier.keep_prob: 0.5}) loss_cur_epoch += loss_this_batch @@ -62,12 +63,14 @@ for _ in range(num_test_batches): x_batch, y_batch = dataset.load_batch(batch_size) average_test_accuracy += sess.run(fetches=classifier.accuracy, - feed_dict={x: x_batch, targets: y_batch, p: 1.}) + feed_dict={classifier.x: x_batch, + classifier.targets: y_batch, + classifier.keep_prob: 1.0}) average_test_accuracy /= num_test_batches print('Training accuracy: {:.03f}'.format(average_test_accuracy)) print('*' * 50) # Save the variables to disk. - save_path = saver.save(sess, './checkpoints/model_epoch_{}.ckpt'.format(epoch)) + save_path = saver.save(sess, join(checkpoint_dir, 'TLC_epoch_{}.ckpt'.format(epoch))) epoch += 1