Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit

Permalink
Match output_name with keras model (#282)
Browse files Browse the repository at this point in the history
* Match output_name with keras
  • Loading branch information
jiafatom authored Oct 24, 2019
1 parent 50426a5 commit 446b606
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions keras2onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,12 +708,9 @@ def _parse_graph_core(graph, keras_node_dict, topology, top_scope, output_names)
:return: The whole topology of the intermediate objects.
"""
input_nodes = set()
raw_model_container = topology.raw_model

# build the node in the working scope.
varset = topology.declare_scope('curr_', top_scope)
for name in output_names:
raw_model_container.add_output_name(name)

model_outputs = []
for name in output_names:
Expand Down Expand Up @@ -797,4 +794,35 @@ def parse_graph(topo, graph, target_opset, output_names):
op.add_output(var1)
topo.raw_model.add_input_name(str_value)

output_name_dict = {}
for idx_, ts_ in enumerate(topo.raw_model.model.outputs):
op = top_level.declare_local_operator(TYPES.Identity)
output_ts = topo.raw_model.model.outputs[idx_]
var_type = _adjust_input_batch_size(_infer_variable_type(output_ts, target_opset))
str_value = output_ts.name
use_ts_name = False
if hasattr(topo.raw_model.model, 'output_names'):
str_value = topo.raw_model.model.output_names[idx_]
elif topo.raw_model.model.outputs[idx_].name.endswith(':0'):
str_value = topo.raw_model.model.outputs[idx_].name[:-2]
else:
# if there is no difference between output tensor name and model output name
# skip it.
use_ts_name = True

if str_value in output_name_dict:
cur_count = output_name_dict[str_value]
output_name_dict[str_value] = cur_count + 1
str_value = str_value + ':' + str(cur_count)
else:
output_name_dict[str_value] = 1

if not use_ts_name:
var0 = top_level.get_local_variable_or_declare_one(str_value, var_type)
var1 = top_level.get_local_variable_or_declare_one(topo.raw_model.model.outputs[idx_].name, var_type)
op.add_input(var1)
op.add_output(var0)

topo.raw_model.add_output_name(str_value)

return _parse_graph_core(graph, keras_layer_ts_map, topo, top_level, output_names)

0 comments on commit 446b606

Please sign in to comment.