Skip to content

Commit

Permalink
Refactor code into train and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Oct 15, 2017
1 parent 1116aaa commit 1c43700
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
53 changes: 53 additions & 0 deletions capstone_traffic_light_classifier/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import cv2
import numpy as np
import tensorflow as tf
from traffic_light_dataset import TrafficLightDataset
from traffic_light_classifier import TrafficLightClassifier


checkpoint_file = './checkpoints/model_epoch_2.ckpt'


if __name__ == '__main__':

# Parameters
n_classes = 4 # Namely `void`, `red`, `yellow`, `green`
input_h, input_w = 64, 64 # Shape to which input is resized

# Init traffic light dataset
dataset = TrafficLightDataset()
dataset.init_from_npy('traffic_light_dataset.npy')

# Placeholders
x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder
p = tf.placeholder(dtype=tf.float32) # dropout keep probability
targets = tf.placeholder(dtype=tf.int32, shape=[None])

# Define model
classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4)

# Add a saver to save the model after each epoch
saver = tf.train.Saver()

with tf.Session() as sess:

# Restore pretrained weights
saver.restore(sess, checkpoint_file)

# Load a batch of data to test the model
x_batch, y_batch = dataset.load_batch(batch_size=16)

# Predict on loaded batch
prediction = sess.run(fetches=classifier.inference, feed_dict={x: x_batch, targets: y_batch, p: 1.})
prediction = np.argmax(prediction, axis=1) # from onehot vectors to labels

# Revert data normalization
x_batch += np.abs(np.min(x_batch))
x_batch *= 255
x_batch = np.clip(x_batch, 0, 255).astype(np.uint8)

# Qualitatively show results
for b in range(x_batch.shape[0]):
image = cv2.resize(x_batch[b], (256, 256))
cv2.imshow('PRED {} GT {}'.format(prediction[b], y_batch[b]), image)
cv2.waitKey()
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

# Init traffic light dataset
dataset = TrafficLightDataset()
# dataset.init_from_files('../traffic_light_dataset', resize=(input_h, input_w))
# dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset'
# dataset.init_from_files(dataset_root, resize=(input_h, input_w))
# dataset.dump_to_npy('traffic_light_dataset.npy')
dataset.init_from_npy('traffic_light_dataset.npy')

Expand All @@ -23,6 +24,9 @@
# Define model
classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4)

# Add a saver to save the model after each epoch
saver = tf.train.Saver()

with tf.Session() as sess:

# Initialize all variables
Expand All @@ -32,6 +36,8 @@
batch_size = 32
batches_each_epoch = 1000

epoch = 0

while True:

loss_cur_epoch = 0
Expand Down Expand Up @@ -60,3 +66,8 @@
average_test_accuracy /= num_test_batches
print('Training accuracy: {:.03f}'.format(average_test_accuracy))
print('*' * 50)

# Save the variables to disk.
save_path = saver.save(sess, './checkpoints/model_epoch_{}.ckpt'.format(epoch))

epoch += 1

0 comments on commit 1c43700

Please sign in to comment.