Skip to content

Commit

Permalink
Refactor model code to include placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Oct 15, 2017
1 parent 4bd851d commit e80ca67
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
14 changes: 8 additions & 6 deletions capstone_traffic_light_classifier/traffic_light_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

class TrafficLightClassifier:

def __init__(self, x, targets, keep_prob, n_classes, learning_rate):
def __init__(self, input_shape, learning_rate):

self.x = x
self.targets = targets
self.keep_prob = keep_prob
# Placeholders
input_h, input_w = input_shape
self.x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder
self.targets = tf.placeholder(dtype=tf.int32, shape=[None])
self.keep_prob = tf.placeholder(dtype=tf.float32) # dropout keep probability

self.n_classes = n_classes # {void, red, yellow, green}
self.n_classes = 4 # {void, red, yellow, green}
self.learning_rate = learning_rate # learning rate used in train step

self._inference = None
Expand Down Expand Up @@ -68,7 +70,7 @@ def loss(self):
def train_step(self):
if self._train_step is None:
with tf.variable_scope('training'):
self._train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss)
self._train_step = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
return self._train_step

@property
Expand Down
35 changes: 19 additions & 16 deletions capstone_traffic_light_classifier/train.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
import tensorflow as tf
from os import makedirs
from os.path import exists, join
from traffic_light_dataset import TrafficLightDataset
from traffic_light_classifier import TrafficLightClassifier


if __name__ == '__main__':

# Parameters
n_classes = 4 # Namely `void`, `red`, `yellow`, `green`
input_h, input_w = 64, 64 # Shape to which input is resized

# Init traffic light dataset
dataset = TrafficLightDataset()
# dataset_root = 'C:/Users/minotauro/Desktop/traffic_light_dataset'
# dataset.init_from_files(dataset_root, resize=(input_h, input_w))
# dataset.dump_to_npy('traffic_light_dataset.npy')
dataset.init_from_npy('traffic_light_dataset.npy')

# Placeholders
x = tf.placeholder(dtype=tf.float32, shape=[None, input_h, input_w, 3]) # input placeholder
p = tf.placeholder(dtype=tf.float32) # dropout keep probability
targets = tf.placeholder(dtype=tf.int32, shape=[None])
dataset_root = 'C:/Users/minotauro/Google Drive/SHARE/traffic_light_dataset'
dataset.init_from_files(dataset_root, resize=(input_h, input_w))
dataset.dump_to_npy('traffic_light_dataset.npy')
# dataset.init_from_npy('traffic_light_dataset.npy')

# Define model
classifier = TrafficLightClassifier(x, targets, p, n_classes, learning_rate=1e-4)
classifier = TrafficLightClassifier(input_shape=[input_h, input_w], learning_rate=1e-4)

# Add a saver to save the model after each epoch
saver = tf.train.Saver()
# Checkpoint stuff
saver = tf.train.Saver() # saver to save the model after each epoch
checkpoint_dir = './checkpoint_2' # checkpoint directory
if not exists(checkpoint_dir):
makedirs(checkpoint_dir)

with tf.Session() as sess:

Expand All @@ -49,7 +48,9 @@

# Actually run one training step here
_, loss_this_batch = sess.run(fetches=[classifier.train_step, classifier.loss],
feed_dict={x: x_batch, targets: y_batch, p: 0.5})
feed_dict={classifier.x: x_batch,
classifier.targets: y_batch,
classifier.keep_prob: 0.5})

loss_cur_epoch += loss_this_batch

Expand All @@ -62,12 +63,14 @@
for _ in range(num_test_batches):
x_batch, y_batch = dataset.load_batch(batch_size)
average_test_accuracy += sess.run(fetches=classifier.accuracy,
feed_dict={x: x_batch, targets: y_batch, p: 1.})
feed_dict={classifier.x: x_batch,
classifier.targets: y_batch,
classifier.keep_prob: 1.0})
average_test_accuracy /= num_test_batches
print('Training accuracy: {:.03f}'.format(average_test_accuracy))
print('*' * 50)

# Save the variables to disk.
save_path = saver.save(sess, './checkpoints/model_epoch_{}.ckpt'.format(epoch))
save_path = saver.save(sess, join(checkpoint_dir, 'TLC_epoch_{}.ckpt'.format(epoch)))

epoch += 1

0 comments on commit e80ca67

Please sign in to comment.