Skip to content

Commit

Permalink
Refactor test to work with new classifier API
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Oct 15, 2017
1 parent e80ca67 commit d270db7
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions capstone_traffic_light_classifier/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,11 @@ def read_and_resize_image(image_path):
# Parse command line arguments
args = parse_arguments()

# Parameters
n_classes = 4 # Namely `void`, `red`, `yellow`, `green`

# Load data on which prediction will be performed
x_batch, y_batch = load_test_data(args)

# Placeholders
x = tf.placeholder(dtype=tf.float32, shape=[None, args.resize_h, args.resize_w, 3]) # input placeholder
p = tf.placeholder(dtype=tf.float32) # dropout keep probability
targets = tf.placeholder(dtype=tf.int32, shape=[None])

# Define model
classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4)
classifier = TrafficLightClassifier(input_shape=[args.resize_h, args.resize_w], learning_rate=1e-4)

# Add a saver to save the model after each epoch
saver = tf.train.Saver()
Expand All @@ -86,7 +78,8 @@ def read_and_resize_image(image_path):
saver.restore(sess, args.checkpoint_path)

# Predict on loaded batch
prediction = sess.run(fetches=classifier.inference, feed_dict={x: x_batch, p: 1.})
prediction = sess.run(fetches=classifier.inference,
feed_dict={classifier.x: x_batch, classifier.keep_prob: 1.})
prediction = np.argmax(prediction, axis=1) # from onehot vectors to labels

# Qualitatively show results
Expand Down

0 comments on commit d270db7

Please sign in to comment.