diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp
index 0de10960b51..1524b1b0073 100644
--- a/compiler/luci/service/src/Nodes/CircleReshape.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp
@@ -20,30 +20,6 @@
 #include "CircleShapeInferenceHelper.h"
 #include "CircleCloneNode.h"
-#include <luci/Log.h>
-
-namespace
-{
-
-std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
-{
-  os << "[";
-  for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
-  {
-    if (r)
-      os << ",";
-
-    if (tensor_shape.dim(r).known())
-      os << tensor_shape.dim(r).value();
-    else
-      os << "?";
-  }
-  os << "]";
-  return os;
-}
-
-} // namespace
-
 namespace luci
 {
@@ -65,93 +41,121 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleReshape *node)
 namespace sinf
 {
+/**
+ * @note CircleReshape always has two inputs: `tensor` and `shape`.
+ *   The `shape` input can be CircleConst, CircleOutputDummy, or CircleNode.
+ *   - If the `shape` input is CircleConst, the shape is inferred from the constant.
+ *   - If the `shape` input is CircleOutputDummy, the shape is inferred from
+ *     the attribute if it exists. If the attribute does not exist,
+ *     the shape is inferred from the node itself.
+ *   - If the `shape` input is CircleNode, the dynamic shape is propagated.
+ */
 loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
 {
-  LOGGER(l);
-
   const loco::DataType S32 = loco::DataType::S32;
-  loco::TensorShape shape_by_input;
-  {
-    LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
+  // CircleReshape node must have `shape` input
+  LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
-    // Only support node's shape() is CircleConst with S32
-    // TODO support other node with other types
-    auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
-    if (const_shape_node != nullptr)
+  bool should_infer = true;
+  loco::TensorShape output_shape;
+  {
+    // Check if `shape` is CircleConst
+    auto const_shape = dynamic_cast<luci::CircleConst *>(node->shape());
+    if (const_shape != nullptr)
     {
-      LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
+      LUCI_ASSERT(const_shape->dtype() == S32, "Only support int32 CircleConst");
+      output_shape.rank(const_shape->size());
-      shape_by_input.rank(const_shape_node->size());
-
-      for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
+      for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
       {
-        shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
+        output_shape.dim(axis) = const_shape->at<S32>(axis);
+        if (const_shape->at<S32>(axis) < 0)
+        {
+          output_shape.dim(axis).unset();
+        }
       }
     }
     else
     {
-      // We use shape from the node itself
-      loco::TensorShape shape;
-      shape.rank(node->rank());
-      for (uint32_t r = 0; r < node->rank(); ++r)
+      // Check if `shape` is CircleOutputDummy
+      auto dummy_shape = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
+      if (dummy_shape != nullptr)
       {
-        // TODO remove this copy from `use_own(node);`
-        // Shape inference rules in this file did not consider unknown dimension.
-        // If some node has unknown dimension, 0 is inserted and wrong shape
-        // inference was done as a result.
-        // To fix this, new shape inference algorithm is being implemented.
-        // Until new inference algorithm is fully implemented, unknown dimension
-        // would be represented as 1 along with TFLite expression.
-        shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
+        if (node->newShape()->rank() > 0)
+        {
+          output_shape.rank(node->newShape()->rank());
+
+          for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
+          {
+            output_shape.dim(axis) = node->newShape()->dim(axis);
+            if (node->newShape()->dim(axis) < 0)
+            {
+              output_shape.dim(axis).unset();
+            }
+          }
+        }
+        else
+        {
+          output_shape = circle_shape(node);
+        }
+      }
+      else
+      {
+        // Check if `shape` is CircleNode
+        auto node_shape = dynamic_cast<luci::CircleNode *>(node->shape());
+        if (node_shape != nullptr)
+        {
+          output_shape.rank(node_shape->dim(0).value());
+
+          for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
+          {
+            output_shape.dim(axis).unset();
+          }
+
+          should_infer = false;
+        }
       }
-      shape_by_input = shape;
-    }
-  }
-
-  loco::TensorShape shape_by_attr;
-  {
-    shape_by_attr.rank(node->newShape()->rank());
-
-    for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
-    {
-      shape_by_attr.dim(axis) = node->newShape()->dim(axis);
-    }
   }
-  if (!(shape_by_input == shape_by_attr))
-  {
-    INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
-    INFO(l) << " shape_by_input : " << shape_by_input << std::endl;
-    INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
-  }
-
-  loco::TensorShape output_shape = shape_by_input;
-
-  // One of the dimensions can have special value -1, meaning its actual value should be inferred.
   const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
   const auto input_shape = circle_shape(input);
   uint32_t input_element_count = 1;
-  uint32_t output_element_count = 1;
-  uint32_t unknown_dim_index = UINT32_MAX;
-  for (uint32_t i = 0; i < input_shape.rank(); ++i)
-    input_element_count *= (input_shape.dim(i).known() ?
input_shape.dim(i).value() : 1);
-  for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
+  for (uint32_t axis = 0; axis < input_shape.rank(); ++axis)
   {
-    const uint32_t dim_value = output_shape.dim(dim_index).value();
-    if (static_cast<int>(dim_value) == -1)
+    if (input_shape.dim(axis).known())
     {
-      LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
-      unknown_dim_index = dim_index;
+      input_element_count *= input_shape.dim(axis).value();
     }
     else
     {
-      output_element_count *= dim_value;
+      should_infer = false;
+      break;
     }
   }
-  if (unknown_dim_index != UINT32_MAX)
+
+  if (should_infer)
   {
-    output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
+    uint32_t output_element_count = 1;
+    uint32_t unknown_dim_index = UINT32_MAX;
+    for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
+    {
+      if (output_shape.dim(dim_index).known() == false)
+      {
+        LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
+        unknown_dim_index = dim_index;
+      }
+      else
+      {
+        const uint32_t dim_value = output_shape.dim(dim_index).value();
+        output_element_count *= dim_value;
+      }
+    }
+    if (unknown_dim_index != UINT32_MAX)
+    {
+      output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
+    }
   }
   return output_shape;
diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp
index a6ae6735500..4a0467ec423 100644
--- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp
@@ -39,25 +39,25 @@ TEST(CloneNodeTest, clone_Reshape)
   ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
 }
-TEST(ShapeRuleTest, reshape_by_input_const_static)
+TEST(ShapeRuleTest, reshape_by_circle_const)
 {
   auto g = loco::make_graph();
   auto node_reshape = g->nodes()->create<luci::CircleReshape>();
   auto tensor_input = g->nodes()->create<luci::CircleInput>();
-  auto shape_by_input =
g->nodes()->create<luci::CircleConst>();
+  auto shape_input = g->nodes()->create<luci::CircleConst>();
   tensor_input->dtype(loco::DataType::S32);
   tensor_input->shape({2, 3, 4});
   tensor_input->shape_status(luci::ShapeStatus::VALID);
-  shape_by_input->dtype(loco::DataType::S32);
-  shape_by_input->size(2);
-  shape_by_input->at<loco::DataType::S32>(0) = 6;
-  shape_by_input->at<loco::DataType::S32>(1) = 4;
-  shape_by_input->shape_status(luci::ShapeStatus::VALID);
+  shape_input->dtype(loco::DataType::S32);
+  shape_input->size(2);
+  shape_input->at<loco::DataType::S32>(0) = -1;
+  shape_input->at<loco::DataType::S32>(1) = 4;
+  shape_input->shape_status(luci::ShapeStatus::VALID);
   node_reshape->tensor(tensor_input);
-  node_reshape->shape(shape_by_input);
+  node_reshape->shape(shape_input);
   loco::TensorShape output_shape;
   luci::sinf::Rule shape_inf_rule;
@@ -71,25 +71,25 @@ TEST(ShapeRuleTest, reshape_by_input_const_static)
   ASSERT_EQ(4, output_shape.dim(1).value());
 }
-TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
+TEST(ShapeRuleTest, reshape_by_circle_dummy)
 {
   auto g = loco::make_graph();
   auto node_reshape = g->nodes()->create<luci::CircleReshape>();
   auto tensor_input = g->nodes()->create<luci::CircleInput>();
-  auto shape_by_input = g->nodes()->create<luci::CircleConst>();
+  auto shape_input = g->nodes()->create<luci::CircleOutputDummy>();
   tensor_input->dtype(loco::DataType::S32);
   tensor_input->shape({2, 3, 4});
   tensor_input->shape_status(luci::ShapeStatus::VALID);
-  shape_by_input->dtype(loco::DataType::S32);
-  shape_by_input->size(2);
-  shape_by_input->at<loco::DataType::S32>(0) = -1;
-  shape_by_input->at<loco::DataType::S32>(1) = 4;
-  shape_by_input->shape_status(luci::ShapeStatus::VALID);
+  shape_input->dtype(loco::DataType::S32);
+  shape_input->shape_status(luci::ShapeStatus::VALID);
   node_reshape->tensor(tensor_input);
-  node_reshape->shape(shape_by_input);
+  node_reshape->shape(shape_input);
+  node_reshape->newShape()->rank(2);
+  node_reshape->newShape()->dim(0) = -1;
+  node_reshape->newShape()->dim(1) = 4;
   loco::TensorShape output_shape;
   luci::sinf::Rule shape_inf_rule;
@@ -102,3 +102,83 @@
   ASSERT_EQ(6, output_shape.dim(0).value());
   ASSERT_EQ(4,
output_shape.dim(1).value());
 }
+
+TEST(ShapeRuleTest, reshape_by_circle_node)
+{
+  auto g = loco::make_graph();
+  auto node_reshape = g->nodes()->create<luci::CircleReshape>();
+  auto tensor_input = g->nodes()->create<luci::CircleInput>();
+  auto shape_input = g->nodes()->create<luci::CircleInput>();
+
+  tensor_input->dtype(loco::DataType::S32);
+  tensor_input->shape({2, 3, 4});
+  tensor_input->shape_status(luci::ShapeStatus::VALID);
+
+  shape_input->dtype(loco::DataType::S32);
+  shape_input->shape({2});
+  shape_input->shape_status(luci::ShapeStatus::VALID);
+
+  node_reshape->tensor(tensor_input);
+  node_reshape->shape(shape_input);
+
+  loco::TensorShape output_shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
+
+  ASSERT_EQ(2, output_shape.rank());
+  ASSERT_FALSE(output_shape.dim(0).known());
+  ASSERT_FALSE(output_shape.dim(1).known());
+}
+
+TEST(ShapeRuleTest, reshape_input_tensor_undefined_NEG)
+{
+  auto g = loco::make_graph();
+  auto node_reshape = g->nodes()->create<luci::CircleReshape>();
+  auto tensor_input = g->nodes()->create<luci::CircleInput>();
+  auto shape_by_input = g->nodes()->create<luci::CircleConst>();
+
+  tensor_input->dtype(loco::DataType::S32);
+  tensor_input->shape({2, 3, 4});
+  tensor_input->shape_status(luci::ShapeStatus::UNDEFINED);
+
+  shape_by_input->dtype(loco::DataType::S32);
+  shape_by_input->size(2);
+  shape_by_input->at<loco::DataType::S32>(0) = 6;
+  shape_by_input->at<loco::DataType::S32>(1) = 4;
+  shape_by_input->shape_status(luci::ShapeStatus::VALID);
+
+  node_reshape->tensor(tensor_input);
+  node_reshape->shape(shape_by_input);
+
+  loco::TensorShape output_shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
+}
+
+TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG)
+{
+  auto g = loco::make_graph();
+  auto node_reshape = g->nodes()->create<luci::CircleReshape>();
+  auto tensor_input = g->nodes()->create<luci::CircleInput>();
+  auto shape_by_input = g->nodes()->create<luci::CircleConst>();
+
+  tensor_input->dtype(loco::DataType::S32);
+  tensor_input->shape({2, 3, 4});
+  tensor_input->shape_status(luci::ShapeStatus::VALID);
+
+
shape_by_input->dtype(loco::DataType::S32);
+  shape_by_input->size(2);
+  shape_by_input->at<loco::DataType::S32>(0) = 6;
+  shape_by_input->at<loco::DataType::S32>(1) = 4;
+  shape_by_input->shape_status(luci::ShapeStatus::UNDEFINED);
+
+  node_reshape->tensor(tensor_input);
+  node_reshape->shape(shape_by_input);
+
+  loco::TensorShape output_shape;
+  luci::sinf::Rule shape_inf_rule;
+
+  ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
+}