Skip to content

Commit

Permalink
Refactor code for loading vgg pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrplz committed Sep 4, 2017
1 parent c500094 commit b6d10ac
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions project_12_road_segmentation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,22 @@ def load_vgg(sess, vgg_path):
:return: Tuple of Tensors from VGG model (image_input, keep_prob, layer3_out, layer4_out, layer7_out)
"""

vgg_tag = 'vgg16'
vgg_input_tensor_name = 'image_input:0'
vgg_keep_prob_tensor_name = 'keep_prob:0'
vgg_layer3_out_tensor_name = 'layer3_out:0'
vgg_layer4_out_tensor_name = 'layer4_out:0'
vgg_layer7_out_tensor_name = 'layer7_out:0'

tf.saved_model.loader.load(sess, [vgg_tag], vgg_path)
tf.saved_model.loader.load(sess, ['vgg16'], vgg_path)
graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name(vgg_input_tensor_name)
keep = graph.get_tensor_by_name(vgg_keep_prob_tensor_name)
w3 = graph.get_tensor_by_name(vgg_layer3_out_tensor_name)
w4 = graph.get_tensor_by_name(vgg_layer4_out_tensor_name)
w7 = graph.get_tensor_by_name(vgg_layer7_out_tensor_name)
image_input = graph.get_tensor_by_name(vgg_input_tensor_name)
keep_prob = graph.get_tensor_by_name(vgg_keep_prob_tensor_name)
layer3_out = graph.get_tensor_by_name(vgg_layer3_out_tensor_name)
layer4_out = graph.get_tensor_by_name(vgg_layer4_out_tensor_name)
layer7_out = graph.get_tensor_by_name(vgg_layer7_out_tensor_name)

return w1, keep, w3, w4, w7
return image_input, keep_prob, layer3_out, layer4_out, layer7_out


def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes):
Expand Down Expand Up @@ -125,7 +124,7 @@ def run():
# https://datascience.stackexchange.com/questions/5224/how-to-prepare-augment-images-for-neural-network

# TODO: Build NN using load_vgg, layers, and optimize function
load_vgg(sess, vgg_path)
image_input, keep_prob, layer3_out, layer4_out, layer7_out = load_vgg(sess, vgg_path)

# TODO: Train NN using the train_nn function

Expand Down

0 comments on commit b6d10ac

Please sign in to comment.