From 1c43700bb515f78f5789c07481ac376d5e852ebb Mon Sep 17 00:00:00 2001 From: ndrplz Date: Sun, 15 Oct 2017 11:42:33 +0200 Subject: [PATCH] Refactor code into train and test --- capstone_traffic_light_classifier/test.py | 53 +++++++++++++++++++ .../{main.py => train.py} | 13 ++++- 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 capstone_traffic_light_classifier/test.py rename capstone_traffic_light_classifier/{main.py => train.py} (85%) diff --git a/capstone_traffic_light_classifier/test.py b/capstone_traffic_light_classifier/test.py new file mode 100644 index 0000000..173b5bc --- /dev/null +++ b/capstone_traffic_light_classifier/test.py @@ -0,0 +1,53 @@ +import cv2 +import numpy as np +import tensorflow as tf +from traffic_light_dataset import TrafficLightDataset +from traffic_light_classifier import TrafficLightClassifier + + +checkpoint_file = './checkpoints/model_epoch_2.ckpt' + + +if __name__ == '__main__': + + # Parameters + n_classes = 4 # Namely `void`, `red`, `yellow`, `green` + input_h, input_w = 64, 64 # Shape to which input is resized + + # Init traffic light dataset + dataset = TrafficLightDataset() + dataset.init_from_npy('traffic_light_dataset.npy') + + # Placeholders + x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder + p = tf.placeholder(dtype=tf.float32) # dropout keep probability + targets = tf.placeholder(dtype=tf.int32, shape=[None]) + + # Define model + classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4) + + # Add a saver to save the model after each epoch + saver = tf.train.Saver() + + with tf.Session() as sess: + + # Restore pretrained weights + saver.restore(sess, checkpoint_file) + + # Load a batch of data to test the model + x_batch, y_batch = dataset.load_batch(batch_size=16) + + # Predict on loaded batch + prediction = sess.run(fetches=classifier.inference, feed_dict={x: x_batch, targets: y_batch, p: 1.}) + prediction = np.argmax(prediction, axis=1) # from onehot vectors to labels + + # Revert data normalization + x_batch += np.abs(np.min(x_batch)) + x_batch *= 255 + x_batch = np.clip(x_batch, 0, 255).astype(np.uint8) + + # Qualitatively show results + for b in range(x_batch.shape[0]): + image = cv2.resize(x_batch[b], (256, 256)) + cv2.imshow('PRED {} GT {}'.format(prediction[b], y_batch[b]), image) + cv2.waitKey() diff --git a/capstone_traffic_light_classifier/main.py b/capstone_traffic_light_classifier/train.py similarity index 85% rename from capstone_traffic_light_classifier/main.py rename to capstone_traffic_light_classifier/train.py index 67264c1..5ce2f21 100644 --- a/capstone_traffic_light_classifier/main.py +++ b/capstone_traffic_light_classifier/train.py @@ -11,7 +11,8 @@ # Init traffic light dataset dataset = TrafficLightDataset() - # dataset.init_from_files('../traffic_light_dataset', resize=(input_h, input_w)) + # dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset' + # dataset.init_from_files(dataset_root, resize=(input_h, input_w)) # dataset.dump_to_npy('traffic_light_dataset.npy') dataset.init_from_npy('traffic_light_dataset.npy') @@ -23,6 +24,9 @@ # Define model classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4) + # Add a saver to save the model after each epoch + saver = tf.train.Saver() + with tf.Session() as sess: # Initialize all variables @@ -32,6 +36,8 @@ batch_size = 32 batches_each_epoch = 1000 + epoch = 0 + while True: loss_cur_epoch = 0 @@ -60,3 +66,8 @@ average_test_accuracy /= num_test_batches print('Training accuracy: {:.03f}'.format(average_test_accuracy)) print('*' * 50) + + # Save the variables to disk. + save_path = saver.save(sess, './checkpoints/model_epoch_{}.ckpt'.format(epoch)) + + epoch += 1