[luci/service] Support dynamic shape inference for reshape
This commit adds support for dynamic shape inference for the reshape operation.

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
jongwonyang committed Sep 12, 2024
1 parent 5149447 commit f699ab3
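To make the new behavior concrete, here is a small standalone sketch of the rule this commit adopts. It is plain C++ rather than the luci/loco API, and the Dim/Shape aliases and the infer_reshape helper are illustrative names only: a -1 entry in the requested shape is resolved from the input's element count only when every input dimension is known, and if any input dimension is dynamic the unknown output dimension simply stays dynamic.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative aliases: a dimension is either a known extent or dynamic (nullopt).
using Dim = std::optional<uint32_t>;
using Shape = std::vector<Dim>;

// Resolve a reshape request against an input shape.
// A negative entry in `request` means "infer this dimension".
// If any input dimension is dynamic, the unknown output dimension stays dynamic.
Shape infer_reshape(const Shape &input, const std::vector<int32_t> &request)
{
  Shape output;
  for (int32_t d : request)
    output.push_back(d < 0 ? Dim{} : Dim{static_cast<uint32_t>(d)});

  uint32_t input_elements = 1;
  for (const Dim &d : input)
  {
    if (!d.has_value())
      return output; // dynamic input: nothing can be inferred
    input_elements *= *d;
  }

  uint32_t known_elements = 1;
  std::size_t unknown_index = SIZE_MAX;
  for (std::size_t i = 0; i < output.size(); ++i)
  {
    if (!output[i].has_value())
    {
      if (unknown_index != SIZE_MAX)
        throw std::runtime_error("More than one unknown dimension");
      unknown_index = i;
    }
    else
    {
      known_elements *= *output[i];
    }
  }

  if (unknown_index != SIZE_MAX)
    output[unknown_index] = input_elements / known_elements;
  return output;
}

int main()
{
  // Static input [2, 3, 4] reshaped with [-1, 6] resolves to [4, 6].
  // Dynamic input [?, 3, 4] reshaped with [-1, 6] stays [?, 6].
  for (const Shape &s : {infer_reshape({2u, 3u, 4u}, {-1, 6}),
                         infer_reshape({Dim{}, 3u, 4u}, {-1, 6})})
  {
    for (const Dim &d : s)
      std::cout << (d.has_value() ? std::to_string(*d) : "?") << ' ';
    std::cout << '\n';
  }
  return 0;
}

With any C++17 compiler the example prints "4 6" for the static input and "? 6" for the dynamic one.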
Showing 3 changed files with 234 additions and 97 deletions.
1 change: 1 addition & 0 deletions compiler/common-artifacts/exclude.lst
@@ -7,6 +7,7 @@
 ## TensorFlowLiteRecipes
 optimize(Add_STR_000) # STRING is not supported
 optimize(Add_STR_001) # STRING is not supported
+optimize(Net_Gather_SparseToDense_AddV2_000) # Constant folding is not generally supported
 
 ## CircleRecipes
 
173 changes: 92 additions & 81 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
@@ -20,29 +20,7 @@
 #include "CircleShapeInferenceHelper.h"
 #include "CircleCloneNode.h"
 
-#include <luci/Log.h>
-
-namespace
-{
-
-std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
-{
-  os << "[";
-  for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
-  {
-    if (r)
-      os << ",";
-
-    if (tensor_shape.dim(r).known())
-      os << tensor_shape.dim(r).value();
-    else
-      os << "?";
-  }
-  os << "]";
-  return os;
-}
-
-} // namespace
+#include <oops/InternalExn.h>
 
 namespace luci
 {
@@ -65,93 +43,126 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
 namespace sinf
 {
 
+/**
+ * @note CircleReshape always has two inputs: tensor and shape.
+ *       The shape input can be CircleConst, CircleOutputDummy, or CircleNode.
+ *       - If the shape input is CircleConst, the shape is inferred from the constant.
+ *       - If the shape input is CircleOutputDummy, the shape is inferred from
+ *         the attribute if it exists. If the attribute does not exist,
+ *         the shape is inferred from the node itself.
+ *       - If the shape input is CircleNode, the shape is not inferred.
+ */
 loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
 {
-  LOGGER(l);
-
   const loco::DataType S32 = loco::DataType::S32;
 
-  loco::TensorShape shape_by_input;
+  // CircleReshape node must have reshape/shape
+  if (node->shape() == nullptr)
   {
-    LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
+    INTERNAL_EXN("2nd input shape() should not be nullptr");
+  }
 
-    // Only support node's shape() is CircleConst with S32
-    // TODO support other node with other types
-    auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
-    if (const_shape_node != nullptr)
+  bool should_infer = true;
+  loco::TensorShape output_shape;
+  {
+    // Check if reshape/shape is CircleConst
+    auto const_input = dynamic_cast<luci::CircleConst *>(node->shape());
+    if (const_input != nullptr)
     {
-      LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
-
-      shape_by_input.rank(const_shape_node->size<S32>());
+      output_shape.rank(const_input->size<S32>());
 
-      for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
+      for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
       {
-        shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
+        output_shape.dim(axis) = const_input->at<S32>(axis);
+        if (const_input->at<S32>(axis) < 0)
+        {
+          output_shape.dim(axis).unset();
+        }
       }
     }
     else
    {
-      // We use shape from the node itself
-      loco::TensorShape shape;
-      shape.rank(node->rank());
-      for (uint32_t r = 0; r < node->rank(); ++r)
+      // Check if reshape/shape is CircleOutputDummy
+      auto dummy_input = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
+      if (dummy_input != nullptr)
       {
-        // TODO remove this copy from `use_own(node);`
-        // Shape inference rules in this file did not consider unknown dimension.
-        // If some node has unknown dimension, 0 is inserted and wrong shape
-        // inference was done as a result.
-        // To fix this, new shape inference algorithm is being implemented.
-        // Until new inference algorithm is fully implemented, unknown dimension
-        // would be represented as 1 along with TFLite expression.
-        shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
+        if (node->newShape()->rank() > 0)
+        {
+          output_shape.rank(node->newShape()->rank());
+
+          for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
+          {
+            output_shape.dim(axis) = node->newShape()->dim(axis);
+            if (node->newShape()->dim(axis) < 0)
+            {
+              output_shape.dim(axis).unset();
+            }
+          }
+        }
+        else
+        {
+          output_shape = circle_shape(node);
+        }
       }
-      shape_by_input = shape;
+      else
+      {
+        // Check if reshape/shape is CircleNode
+        auto node_input = dynamic_cast<luci::CircleNode *>(node->shape());
+        if (node_input != nullptr)
+        {
+          output_shape.rank(node_input->dim(0).value());
+
+          for (uint32_t axis = 0; axis < output_shape.rank(); ++axis)
+          {
+            output_shape.dim(axis).unset();
+          }
+
+          should_infer = false;
+        }
+      }
     }
   }
-
-  loco::TensorShape shape_by_attr;
-  {
-    shape_by_attr.rank(node->newShape()->rank());
-
-    for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
-    {
-      shape_by_attr.dim(axis) = node->newShape()->dim(axis);
-    }
-  }
-
-  if (!(shape_by_input == shape_by_attr))
-  {
-    INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
-    INFO(l) << "   shape_by_input : " << shape_by_input << std::endl;
-    INFO(l) << "   shape_by_attr : " << shape_by_attr << std::endl;
-  }
-
-  loco::TensorShape output_shape = shape_by_input;
-
   // One of the dimensions can have special value -1, meaning its actual value should be inferred.
   const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());
   const auto input_shape = circle_shape(input);
   uint32_t input_element_count = 1;
-  uint32_t output_element_count = 1;
-  uint32_t unknown_dim_index = UINT32_MAX;
-  for (uint32_t i = 0; i < input_shape.rank(); ++i)
-    input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
-  for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
+  for (uint32_t axis = 0; axis < input_shape.rank(); ++axis)
   {
-    const uint32_t dim_value = output_shape.dim(dim_index).value();
-    if (static_cast<int>(dim_value) == -1)
+    if (input_shape.dim(axis).known())
     {
-      LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
-      unknown_dim_index = dim_index;
+      input_element_count *= input_shape.dim(axis).value();
     }
     else
     {
-      output_element_count *= dim_value;
+      should_infer = false;
+      break;
     }
   }
-  if (unknown_dim_index != UINT32_MAX)
+
+  if (should_infer)
   {
-    output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
+    uint32_t output_element_count = 1;
+    uint32_t unknown_dim_index = UINT32_MAX;
+    for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
+    {
+      if (output_shape.dim(dim_index).known() == false)
+      {
+        if (unknown_dim_index != UINT32_MAX)
+        {
+          INTERNAL_EXN("More than one unknown dimension");
+        }
+        unknown_dim_index = dim_index;
+      }
+      else
+      {
+        const uint32_t dim_value = output_shape.dim(dim_index).value();
+        output_element_count *= dim_value;
+      }
+    }
+    if (unknown_dim_index != UINT32_MAX)
+    {
+      output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
+    }
  }
 
   return output_shape;
(Diff for the third changed file not shown.)
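The @note added to CircleReshape.cpp above distinguishes three kinds of shape input: CircleConst, CircleOutputDummy, and an arbitrary CircleNode. The sketch below only illustrates that dispatch with simplified, hypothetical types (ShapeInput and requested_shape are not luci names): a constant shape or the new_shape attribute may still have a -1 entry resolved later, while a shape computed at runtime fixes only the rank and leaves every dimension dynamic.

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical, simplified model of the three shape-input kinds.
struct ShapeInput
{
  enum class Kind { Const, OutputDummy, Node };
  Kind kind;
  std::vector<int32_t> values; // shape values, used for Const
  uint32_t rank = 0;           // length of the shape tensor, used for Node
};

using Dim = std::optional<uint32_t>; // nullopt means "dynamic"

// Returns the requested output shape and whether a -1 entry may later be resolved.
std::pair<std::vector<Dim>, bool> requested_shape(const ShapeInput &shape,
                                                  const std::vector<int32_t> &attr_new_shape)
{
  std::vector<Dim> out;
  bool may_infer = true;

  switch (shape.kind)
  {
    case ShapeInput::Kind::Const:
      // Constant shape: negative entries start out dynamic but may be inferred.
      for (int32_t v : shape.values)
        out.push_back(v < 0 ? Dim{} : Dim{static_cast<uint32_t>(v)});
      break;
    case ShapeInput::Kind::OutputDummy:
      // No shape tensor: fall back to the new_shape attribute.
      for (int32_t v : attr_new_shape)
        out.push_back(v < 0 ? Dim{} : Dim{static_cast<uint32_t>(v)});
      break;
    case ShapeInput::Kind::Node:
      // Shape computed at runtime: only the rank is known; every dimension stays dynamic.
      out.assign(shape.rank, Dim{});
      may_infer = false;
      break;
  }
  return {out, may_infer};
}

int main()
{
  ShapeInput runtime_shape{ShapeInput::Kind::Node, {}, 2};
  auto [dims, may_infer] = requested_shape(runtime_shape, {});
  std::cout << "rank=" << dims.size() << " may_infer=" << std::boolalpha << may_infer << '\n';
  // Prints: rank=2 may_infer=false
  return 0;
}

The real implementation also falls back to the node's own shape when the CircleOutputDummy case finds an empty new_shape attribute; that branch is omitted here for brevity.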

0 comments on commit f699ab3
