diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst index 6f32a2966af..0fd24c88ad8 100644 --- a/compiler/circle2circle-dredd-recipe-test/test.lst +++ b/compiler/circle2circle-dredd-recipe-test/test.lst @@ -63,6 +63,7 @@ Add(Net_InstanceNorm_005 PASS fuse_instnorm) Add(Net_InstanceNorm_006 PASS fuse_instnorm) Add(Net_InstanceNorm_007 PASS fuse_instnorm) Add(Net_InstanceNorm_008 PASS fuse_instnorm) +Add(Net_InstanceNorm_009 PASS fuse_instnorm) Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6) Add(Net_Mul_Add_000 PASS remove_unnecessary_add) Add(Net_Mul_Add_001 PASS remove_unnecessary_add) diff --git a/compiler/luci-pass-value-py-test/test.lst b/compiler/luci-pass-value-py-test/test.lst index 7812be16991..8453962453f 100644 --- a/compiler/luci-pass-value-py-test/test.lst +++ b/compiler/luci-pass-value-py-test/test.lst @@ -42,6 +42,7 @@ eval(Net_InstanceNorm_001 fuse_instnorm) eval(Net_InstanceNorm_002 fuse_instnorm) eval(Net_InstanceNorm_003 fuse_instnorm) eval(Net_InstanceNorm_008 fuse_instnorm) +eval(Net_InstanceNorm_009 fuse_instnorm) eval(Net_Mul_Add_000 remove_unnecessary_add) eval(Net_Mul_Add_001 remove_unnecessary_add) eval(Net_Mul_Add_002 remove_unnecessary_add) diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp index 5427e1fe69c..1cc8302969b 100644 --- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp +++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp @@ -22,6 +22,7 @@ #include #include +#include #include #include @@ -741,6 +742,12 @@ template <> bool InstanceNormPattern::matchgraph(); @@ -1075,6 +1082,11 @@ uint32_t PostFusion::input_channel(void) if (input_rank < 1) return 0; + if (input_rank == 3) + { + // use dim 1 + return input->dim(1).value(); + } // assume channel-last return input->dim(input_rank - 1).value(); } diff --git a/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.recipe b/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.recipe new file mode 100644 index 00000000000..351d80172c1 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.recipe @@ -0,0 +1,184 @@ +# +# This was copied from Net_InstanceNorm_008 +# with last dim value > 1 +# + +operand { + name: "Hole" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 8 + } +} +operand { + name: "InstanceNorm/beta" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } + filler { + tag: "gaussian" + arg: "0.0" + arg: "1.0" + } +} +operand { + name: "InstanceNorm/instancenorm/add/y" + type: FLOAT32 + shape { + } + filler { + tag: "explicit" + arg: "1e-06" + } +} +operand { + name: "InstanceNorm/moments/variance/reduction_indices" + type: INT32 + shape { + dim: 1 + } + filler { + tag: "explicit" + arg: "2" + } +} +operand { + name: "InstanceNorm/moments/mean" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } +} +operand { + name: "InstanceNorm/moments/SquaredDifference" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 8 + } +} +operand { + name: "InstanceNorm/moments/variance" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } +} +operand { + name: "InstanceNorm/instancenorm/add" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 8 + } +} +operand { + name: "InstanceNorm/instancenorm/Rsqrt" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } +} +operand { + name: "InstanceNorm/instancenorm/mul_1" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } +} +operand { + name: "InstanceNorm/instancenorm/mul_2" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 8 + } +} +operand { + name: "InstanceNorm/instancenorm/sub" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 1 + } +} +operand { + name: "InstanceNorm/instancenorm/add_1" + type: FLOAT32 + shape { + dim: 1 dim: 4 dim: 8 + } +} +operation { + type: "Mean" + input: "Hole" + input: "InstanceNorm/moments/variance/reduction_indices" + output: "InstanceNorm/moments/mean" + mean_options { + keep_dims: true + } +} +operation { + type: "SquaredDifference" + input: "Hole" + input: "InstanceNorm/moments/mean" + output: "InstanceNorm/moments/SquaredDifference" +} +operation { + type: "Mean" + input: "InstanceNorm/moments/SquaredDifference" + input: "InstanceNorm/moments/variance/reduction_indices" + output: "InstanceNorm/moments/variance" + mean_options { + keep_dims: true + } +} +operation { + type: "Add" + input: "InstanceNorm/moments/variance" + input: "InstanceNorm/instancenorm/add/y" + output: "InstanceNorm/instancenorm/add" + add_options { + activation: NONE + } +} +operation { + type: "Rsqrt" + input: "InstanceNorm/instancenorm/add" + output: "InstanceNorm/instancenorm/Rsqrt" +} +operation { + type: "Mul" + input: "Hole" + input: "InstanceNorm/instancenorm/Rsqrt" + output: "InstanceNorm/instancenorm/mul_1" + mul_options { + activation: NONE + } +} +operation { + type: "Mul" + input: "InstanceNorm/moments/mean" + input: "InstanceNorm/instancenorm/Rsqrt" + output: "InstanceNorm/instancenorm/mul_2" + mul_options { + activation: NONE + } +} +operation { + type: "Sub" + input: "InstanceNorm/beta" + input: "InstanceNorm/instancenorm/mul_2" + output: "InstanceNorm/instancenorm/sub" + sub_options { + activation: NONE + } +} +operation { + type: "Add" + input: "InstanceNorm/instancenorm/mul_1" + input: "InstanceNorm/instancenorm/sub" + output: "InstanceNorm/instancenorm/add_1" + add_options { + activation: NONE + } +} +input: "Hole" +output: "InstanceNorm/instancenorm/add_1" diff --git a/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.rule b/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.rule new file mode 100644 index 00000000000..e8af35f05c3 --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_InstanceNorm_009/test.rule @@ -0,0 +1,13 @@ +# To check if this network is converted to circle InstanceNorm op + +RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1 + +RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1 +RULE "NO_ADD" $(op_count ADD) '=' 0 +RULE "NO_MUL" $(op_count MUL) '=' 0 +RULE "NO_POW" $(op_count POW) '=' 0 +RULE "NO_DIV" $(op_count DIV) '=' 0 +RULE "NO_SQUARED_DIFF" $(op_count SQUARED_DIFFERENCE) '=' 0 +RULE "NO_MEAN" $(op_count MEAN) '=' 0 +RULE "NO_RSQRT" $(op_count RSQRT) '=' 0 +RULE "NO_SUB" $(op_count SUB) '=' 0