Skip to content

Commit

Permalink
Implement sketch of training pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Sep 6, 2017
1 parent 2e221f1 commit 9a0698a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
64 changes: 34 additions & 30 deletions project_12_road_segmentation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,47 +103,55 @@ def optimize(net_prediction, labels, learning_rate, num_classes):
return logits_flat, train_step, cross_entropy_loss


def train_nn(sess, epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss, input_image,
correct_label, keep_prob, learning_rate):
def train_nn(sess, training_epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss,
image_input, labels, keep_prob, learning_rate):
"""
Train neural network and print out the loss during training.
:param sess: TF Session
:param epochs: Number of epochs
:param training_epochs: Number of epochs
:param batch_size: Batch size
:param get_batches_fn: Function to get batches of training data. Call using get_batches_fn(batch_size)
:param train_op: TF Operation to train the neural network
:param cross_entropy_loss: TF Tensor for the amount of loss
:param input_image: TF Placeholder for input images
:param correct_label: TF Placeholder for label images
:param image_input: TF Placeholder for input images
:param labels: TF Placeholder for label images
:param keep_prob: TF Placeholder for dropout keep probability
:param learning_rate: TF Placeholder for learning rate
"""
# TODO: Implement function
pass

# Variable initialization
sess.run(tf.global_variables_initializer())

lr = 1e-4

for e in range(0, training_epochs):

# Load a batch of examples
batch_x, batch_y = next(get_batches_fn(batch_size))

_, cur_loss = sess.run(fetches=[train_op, cross_entropy_loss],
feed_dict={image_input: batch_x, labels: batch_y, keep_prob: 0.5, learning_rate: lr})

print(cur_loss)


def perform_tests():
tests.test_for_kitti_dataset(data_dir)
tests.test_load_vgg(load_vgg, tf)
tests.test_layers(layers)
tests.test_optimize(optimize)
# tests.test_train_nn(train_nn)
pass
tests.test_train_nn(train_nn)


def run():

num_classes = 2

image_h, image_w = (160, 576)
data_dir = '/home/minotauro/code/self-driving-car/project_12_road_segmentation/data'
runs_dir = '/home/minotauro/code/self-driving-car/project_12_road_segmentation/runs'
tests.test_for_kitti_dataset(data_dir)

# Download pretrained vgg model
helper.maybe_download_pretrained_vgg(data_dir)

# OPTIONAL: Train and Inference on the cityscapes dataset instead of the Kitti dataset.
# You'll need a GPU with at least 10 teraFLOPS to train on.
# https://www.cityscapes-dataset.com/

with tf.Session() as sess:

# Path to vgg model
Expand All @@ -152,38 +160,34 @@ def run():
# Create function to get batches
batch_generator = helper.gen_batch_function(join(data_dir, 'data_road/training'), (image_h, image_w))

# OPTIONAL: Augment Images for better results
# https://datascience.stackexchange.com/questions/5224/how-to-prepare-augment-images-for-neural-network

# TODO: Build NN using load_vgg, layers, and optimize function
x, y = next(batch_generator(batch_size=1))

# Load VGG pretrained
image_input, keep_prob, vgg_layer3_out, vgg_layer4_out, vgg_layer7_out = load_vgg(sess, vgg_path)

# Add skip connections
output = layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes)

# Variable initialization
sess.run(tf.global_variables_initializer())

# Define placeholders
labels = tf.placeholder(tf.float32, shape=[None, image_h, image_w, num_classes])
learning_rate = tf.placeholder(tf.float32, shape=[])

logits, train_op, cross_entropy_loss = optimize(output, labels, learning_rate, num_classes)

sess.run(output, feed_dict={image_input: x, keep_prob: 1.0})
pass
# TODO: Train NN using the train_nn function
# Training parameters
training_epochs = 100
batch_size = 1

train_nn(sess, training_epochs, batch_size, batch_generator, train_op, cross_entropy_loss,
image_input, labels, keep_prob, learning_rate)

# TODO: Save inference data using helper.save_inference_samples
# helper.save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image)

# OPTIONAL: Apply the trained model to a video


if __name__ == '__main__':

data_dir = '/home/minotauro/code/self-driving-car/project_12_road_segmentation/data'
runs_dir = '/home/minotauro/code/self-driving-car/project_12_road_segmentation/runs'

perform_tests()

run()
12 changes: 6 additions & 6 deletions project_12_road_segmentation/project_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def test_train_nn(train_nn):
epochs = 1
batch_size = 2

def get_batches_fn(batach_size_parm):
shape = [batach_size_parm, 2, 3, 3]
return np.arange(np.prod(shape)).reshape(shape)
def get_batches_fn(batch_size_parm):
shape = [batch_size_parm, 2, 3, 3]
yield np.arange(np.prod(shape)).reshape(shape)

train_op = tf.constant(0)
cross_entropy_loss = tf.constant(10.11)
Expand All @@ -128,13 +128,13 @@ def get_batches_fn(batach_size_parm):
with tf.Session() as sess:
parameters = {
'sess': sess,
'epochs': epochs,
'training_epochs': epochs,
'batch_size': batch_size,
'get_batches_fn': get_batches_fn,
'train_op': train_op,
'cross_entropy_loss': cross_entropy_loss,
'input_image': input_image,
'correct_label': correct_label,
'image_input': input_image,
'labels': correct_label,
'keep_prob': keep_prob,
'learning_rate': learning_rate}
_prevent_print(train_nn, parameters)
Expand Down

0 comments on commit 9a0698a

Please sign in to comment.