diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index b9be778a908..0f3bfdf7855 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -122,7 +122,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleRelu0To1 *node) final; // loco::TensorShape visit(const luci::CircleRelu6 *node) final; // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final; - // loco::TensorShape visit(const luci::CircleReshape *node) final; + loco::TensorShape visit(const luci::CircleReshape *node) final; // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final; // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final; // loco::TensorShape visit(const luci::CircleReverseSequence *node) final; diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 4fa8405a5bc..b6c5d0d5bb7 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -911,84 +911,6 @@ loco::NodeShape infer_p_relu(const luci::CirclePRelu *node) return loco::NodeShape{output_shape}; } -loco::NodeShape infer_reshape(const luci::CircleReshape *node) -{ - LOGGER(l); - - const loco::DataType S32 = loco::DataType::S32; - - loco::TensorShape shape_by_input; - { - LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); - - // Only support node's shape() is CircleConst with S32 - // TODO support other node with other types - auto const_shape_node = dynamic_cast(node->shape()); - if (const_shape_node != nullptr) - { - LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); - - shape_by_input.rank(const_shape_node->size()); - - for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) - { - shape_by_input.dim(axis) = const_shape_node->at(axis); - } - } - else - { - // We use shape from the node itself - shape_by_input = own_shape(node); - } - } - - loco::TensorShape shape_by_attr; - { - shape_by_attr.rank(node->newShape()->rank()); - - for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) - { - shape_by_attr.dim(axis) = node->newShape()->dim(axis); - } - } - - if (!(shape_by_input == shape_by_attr)) - { - INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl; - INFO(l) << " shape_by_input : " << shape_by_input << std::endl; - INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl; - } - - loco::TensorShape output_shape = shape_by_input; - - // One of the dimensions can have special value -1, meaning its actual value should be inferred. - const auto input_shape = luci::shape_get(node->tensor()).as(); - uint32_t input_element_count = 1; - uint32_t output_element_count = 1; - uint32_t unknown_dim_index = UINT32_MAX; - for (uint32_t i = 0; i < input_shape.rank(); ++i) - input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); - for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) - { - const uint32_t dim_value = output_shape.dim(dim_index).value(); - if (static_cast(dim_value) == -1) - { - LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); - unknown_dim_index = dim_index; - } - else - { - output_element_count *= dim_value; - } - } - if (unknown_dim_index != UINT32_MAX) - { - output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; - } - - return loco::NodeShape{output_shape}; -} - template loco::NodeShape infer_resize_type(const CIRCLENODE *node) { auto input_shape = luci::shape_get(node->input()).template as(); @@ -2121,15 +2043,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor + +namespace +{ + +std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape) +{ + os << "["; + for (uint32_t r = 0; r < tensor_shape.rank(); ++r) + { + if (r) + os << ","; + + if (tensor_shape.dim(r).known()) + os << tensor_shape.dim(r).value(); + else + os << "?"; + } + os << "]"; + return os; +} + +} // namespace + namespace luci { @@ -34,4 +62,101 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleReshape *node) return cloned; } +namespace sinf +{ + +loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) +{ + LOGGER(l); + + const loco::DataType S32 = loco::DataType::S32; + + loco::TensorShape shape_by_input; + { + LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + + // Only support node's shape() is CircleConst with S32 + // TODO support other node with other types + auto const_shape_node = dynamic_cast(node->shape()); + if (const_shape_node != nullptr) + { + LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst"); + + shape_by_input.rank(const_shape_node->size()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + shape_by_input.dim(axis) = const_shape_node->at(axis); + } + } + else + { + // We use shape from the node itself + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t r = 0; r < node->rank(); ++r) + { + // TODO remove this copy from `use_own(node);` + // Shape inference rules in this file did not consider unknown dimension. + // If some node has unknown dimension, 0 is inserted and wrong shape + // inference was done as a result. + // To fix this, new shape inference algorithm is being implemented. + // Until new inference algorithm is fully implemented, unknown dimension + // would be represented as 1 along with TFLite expression. + shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1; + } + shape_by_input = shape; + } + } + + loco::TensorShape shape_by_attr; + { + shape_by_attr.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) + { + shape_by_attr.dim(axis) = node->newShape()->dim(axis); + } + } + + if (!(shape_by_input == shape_by_attr)) + { + INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl; + INFO(l) << " shape_by_input : " << shape_by_input << std::endl; + INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl; + } + + loco::TensorShape output_shape = shape_by_input; + + // One of the dimensions can have special value -1, meaning its actual value should be inferred. + const auto input = loco::must_cast(node->tensor()); + const auto input_shape = circle_shape(input); + uint32_t input_element_count = 1; + uint32_t output_element_count = 1; + uint32_t unknown_dim_index = UINT32_MAX; + for (uint32_t i = 0; i < input_shape.rank(); ++i) + input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); + for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) + { + const uint32_t dim_value = output_shape.dim(dim_index).value(); + if (static_cast(dim_value) == -1) + { + LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension"); + unknown_dim_index = dim_index; + } + else + { + output_element_count *= dim_value; + } + } + if (unknown_dim_index != UINT32_MAX) + { + output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; + } + + return output_shape; +} + +} // namespace sinf + } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index ca92b717d0c..a6ae6735500 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -15,6 +15,7 @@ */ #include "luci/Service/CircleNodeClone.h" +#include "luci/Service/CircleShapeInference.h" #include @@ -37,3 +38,67 @@ TEST(CloneNodeTest, clone_Reshape) ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0)); ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1)); } + +TEST(ShapeRuleTest, reshape_by_input_const_static) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = 6; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +TEST(ShapeRuleTest, reshape_by_input_const_dynamic) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = -1; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +}