From d4fa1e1053aa6aa7e47f6e5e3fe4d4fd5c55c39d Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Tue, 6 Aug 2024 15:54:40 +0300 Subject: [PATCH 1/8] [luci/pass] Introduce FuseGRU Pass This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU. ONE-DCO-1.0-Signed-off-by: Artem Balyshev ONE-DCO-1.0-Signed-off-by: Chunseok Lee --- .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../luci/pass/include/luci/Pass/FuseGRUPass.h | 39 + compiler/luci/pass/src/CircleOptimizer.cpp | 3 +- compiler/luci/pass/src/FuseGRUPass.cpp | 674 ++++++++++++++++++ compiler/luci/pass/src/FuseGRUPass.test.cpp | 418 +++++++++++ 5 files changed, 1134 insertions(+), 1 deletion(-) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseGRUPass.h create mode 100644 compiler/luci/pass/src/FuseGRUPass.cpp create mode 100644 compiler/luci/pass/src/FuseGRUPass.test.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index ed7cbf611df..14323639f81 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -77,6 +77,7 @@ class CircleOptimizer final FuseActivationFunction, FusePRelu, FuseGelu, + FuseGRU, FuseRsqrt, FuseRmsNorm, FuseRoPE, diff --git a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h new file mode 100644 index 00000000000..152dc427d95 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_GRU_PASS_H__ +#define __LUCI_FUSE_GRU_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleGRU + * + * For detailed subgraph pattern to be fused, please check its implementation. + */ +struct FuseGRUPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseGRUPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_GRU_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index ef6a2d86a4d..ea38e460393 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -52,6 +52,7 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" +#include "luci/Pass/FuseGRUPass.h" #include "luci/Pass/FuseRsqrtPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" @@ -398,7 +399,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const option_to_pass[Options::Algorithm::XpSepActFromTransposeConv] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardReshapeToUnaryOp] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardTransposeOp] = &createPassInstance; - // clang-format on + // clang-format on for (auto const &m : option_to_pass) { diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp new file mode 100644 index 00000000000..2f1f2d341ef --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -0,0 +1,674 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" +#include "helpers/NodeFiller.h" + +#include + +#include +#include + +#include + +#include + +// Helper to fuse GRU +namespace +{ + +class GRUPatternBase +{ +public: + GRUPatternBase(luci::CircleNode *candidate) { _pattern_last_node = candidate; } + + virtual ~GRUPatternBase() = default; + +public: + virtual bool matched() = 0; + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleConst *_weight_ih = nullptr; + luci::CircleConst *_bias_ih = nullptr; + luci::CircleConst *_weight_hh = nullptr; + luci::CircleConst *_bias_hh = nullptr; + + luci::CircleConst *_hidden_input = nullptr; + + luci::CircleConst *_less_const = nullptr; + + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleNode *_pattern_last_node = nullptr; +}; + +/** + * Below diagram shows GRU pattern to fuse. + * Note: this pattern for GRU with `return_sequences=False` + * - the below pattern will be replaced with one GRU + * Main Graph: + * [In] [CircleConst] [CircleConst] [CircleConst] [CircleConst] + * | | | | | + * V | | | | + * [CircleWhile]<----------------------------------------------------- + * | + * V + * [CircleWhileOut] + * | + * V + * [Out] + * + * Condition Graph: + * [In] [CircleConst] (scalar int32 value) + * | | + * V | + * [Less]------ + * | + * V + * [Out] + * + * Body Graph must contain: + * - 2 CircleFullyConnected nodes; + * - 3 CircleMul nodes; + * - 2 CircleLogistic nodes; + * - 2 CircleSplit nodes; + * - 6 CircleAdd nodes; + * - 1 CircleGather node; + * - 1 CircleReshape node; + * - 1 CircleSub node; + * - 1 CircleTanh node; + * - 6 CircleSplitOut nodes; + * - 5 CircleInput nodes; + * - 5 CircleOutput nodes; + * + * Body Graph: + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | [Add_1] -------+----+---------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ +class GRUPattern1 final : public GRUPatternBase +{ +public: + GRUPattern1(luci::CircleWhileOut *candidate) : GRUPatternBase(candidate) + { + assert(candidate); + _while_out_node = candidate; + } + +public: + bool matched() override; +}; + +bool GRUPattern1::matched() +{ + // 0 - check while node + _while_node = dynamic_cast(_while_out_node->input()); + if (_while_node == nullptr) + return false; + + // 1 - check condition graph: only one Less operation + // with scalar int const value + { + const auto cond_graph = _while_node->cond_graph(); + + const auto cond_nodes = loco::active_nodes(loco::output_nodes(cond_graph)); + if (cond_nodes.size() != 4) + return false; + luci::CircleLess *less_node = nullptr; + for (auto node : cond_nodes) + { + less_node = dynamic_cast(node); + if (less_node != nullptr) + break; + } + + // doesn't find Less node + if (less_node == nullptr) + return false; + + luci::CircleNode *less_input; + if (not luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)) + return false; + + if (_less_const->dtype() != loco::DataType::S32) + return false; + + if (_less_const->size() != 1) + return false; + + assert(_less_const->at(0) > 0); + } + + // 2 - Check while's input nodes + // Save hidden state input node + { + if (_while_node->input_count() != 5) + return false; + + // Save input node + _ifm = dynamic_cast(_while_node->input(4)); + if (_ifm == nullptr) + return false; + + _hidden_input = dynamic_cast(_while_node->input(3)); + if (_hidden_input == nullptr) + return false; + } + + // 3 - check body graph + { + const auto body_graph = _while_node->body_graph(); + + if (loco::input_nodes(body_graph).size() != 5) + return false; + + if (loco::output_nodes(body_graph).size() != 5) + return false; + + const auto body_nodes = loco::active_nodes(loco::output_nodes(body_graph)); + + // Save all nodes according its types + std::vector fc_nodes; + std::vector split_nodes; + std::vector logistic_nodes; + std::vector mul_nodes; + std::vector add_nodes; + std::vector sub_nodes; + std::vector reshape_nodes; + std::vector gather_nodes; + std::vector tanh_nodes; + std::vector split_out_nodes; + + for (auto node : body_nodes) + { + auto circle_node = dynamic_cast(node); + switch (circle_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + case luci::CircleOpcode::CIRCLEINPUT: + case luci::CircleOpcode::CIRCLEOUTPUT: + case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE: + break; + case luci::CircleOpcode::FULLY_CONNECTED: + fc_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SPLIT: + split_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::LOGISTIC: + logistic_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::MUL: + mul_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::ADD: + add_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SUB: + sub_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::RESHAPE: + reshape_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::GATHER: + gather_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::TANH: + tanh_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + split_out_nodes.push_back(dynamic_cast(circle_node)); + break; + default: + return false; + } + } + + // Check number of nodes + if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or + split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or + reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or + split_out_nodes.size() != 6) + return false; + + // Check structure + // TODO: add more checks + { + // 1 - Check Split ops + // Both has FC nodes as input + // Axis is const + for (auto node : split_nodes) + { + if (dynamic_cast(node->split_dim()) == nullptr or + dynamic_cast(node->input()) == nullptr) + return false; + } + + // 2 - Check Logistic ops + // Add is input node for both nodes + for (auto node : logistic_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + + // 3 - Check Sub + // Const - is first input node + // Logistic - is second input node + for (auto node : sub_nodes) + { + if (dynamic_cast(node->y()) == nullptr or + dynamic_cast(node->x()) == nullptr) + return false; + } + + // 4 - Check Add + // Mul or Const or Input or Split ops can be input nodes + // Mul - 3 times as input + // Const - 2 times as input + // Input - 2 times as input + // Split - 5 times as input + { + int num_mul = 0; + int num_const = 0; + int num_input = 0; + int num_split = 0; + for (auto node : add_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + default: + return false; + } + } + if (num_mul != 3 or num_split != 5 or num_const != 2 or num_input != 2) + return false; + } + } + + // 5 - Check Mul + // Logistic or Tanh or Sub or Input or Split ops can be input nodes + // Logistic - 2 times as input + // Tanh - 1 times as input + // Sub - 1 times as input + // Split - 1 times as input + // Input - 1 times as input + { + int num_logistic = 0; + int num_tanh = 0; + int num_sub = 0; + int num_split = 0; + int num_input = 0; + for (auto node : mul_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + } + if (num_logistic != 2 or num_tanh != 1 or num_sub != 1 or num_split != 1 or num_input != 1) + return false; + } + + // 6 - Check Gather + // Gather has two CircleInput as input + { + for (auto node : gather_nodes) + { + if (dynamic_cast(node->indices()) == nullptr) + return false; + + if (dynamic_cast(node->params()) == nullptr) + return false; + } + } + + // 7 - Check Tanh + // Input is CircleAdd + { + for (auto node : tanh_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + } + + // Find input and hidden FC weights and biases + for (auto node : body_nodes) + { + auto *fc_node = dynamic_cast(node); + if (fc_node == nullptr) + continue; + + const auto input_node = dynamic_cast(fc_node->input()); + if (input_node == nullptr) + return false; + + // For input hidden FullyConnected - input node is CircleInput node + if (dynamic_cast(input_node) != nullptr) + { + _weight_ih = dynamic_cast(fc_node->weights()); + _bias_ih = dynamic_cast(fc_node->bias()); + } + // For hidden hidden FullyConnected - input node is CircleGather node + else if (dynamic_cast(input_node) != nullptr) + { + _weight_hh = dynamic_cast(fc_node->weights()); + _bias_hh = dynamic_cast(fc_node->bias()); + } + else + { + return false; + } + } + + if (_weight_ih == nullptr or _weight_hh == nullptr) + return false; + } + + return true; +} + +class FuseGRU final +{ +public: + FuseGRU(const GRUPatternBase *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleGRU *create_circle_gru(loco::Graph *graph); + +private: + const GRUPatternBase *_p; +}; + +template +void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned) +{ + assert(T == node->dtype()); + assert(T == cloned->dtype()); + + const auto size = node->size(); + cloned->size(size); + for (uint32_t i = 0; i < size; i++) + cloned->at(i) = node->at(i); +} + +luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph) +{ + auto cloned = graph->nodes()->create(); + + if (cloned != nullptr) + { + // dtype/shape + cloned->dtype(node->dtype()); + cloned->rank(node->rank()); + + // values + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + copy_values(node, cloned); + break; + + case loco::DataType::U8: + copy_values(node, cloned); + break; + + case loco::DataType::S8: + copy_values(node, cloned); + break; + + case loco::DataType::S16: + copy_values(node, cloned); + break; + + case loco::DataType::S32: + copy_values(node, cloned); + break; + + case loco::DataType::S64: + copy_values(node, cloned); + break; + + case loco::DataType::BOOL: + copy_values(node, cloned); + break; + + default: + assert(false); + } + } + + return cloned; +} + +luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) +{ + assert(graph); + + auto weight_ih_cloned = clone_circleconst(_p->_weight_ih, graph); + luci::copy_common_attributes(_p->_weight_ih, weight_ih_cloned); + + auto weight_hh_cloned = clone_circleconst(_p->_weight_hh, graph); + luci::copy_common_attributes(_p->_weight_hh, weight_hh_cloned); + + luci::CircleNode *bias_ih_cloned = nullptr; + if (_p->_bias_ih != nullptr) + { + bias_ih_cloned = clone_circleconst(_p->_bias_ih, graph); + luci::copy_common_attributes(_p->_bias_ih, bias_ih_cloned); + } + else + { + bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + luci::CircleNode *bias_hh_cloned = nullptr; + if (_p->_bias_hh != nullptr) + { + bias_hh_cloned = clone_circleconst(_p->_bias_hh, graph); + luci::copy_common_attributes(_p->_bias_hh, bias_hh_cloned); + } + else + { + bias_hh_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); + luci::copy_common_attributes(_p->_hidden_input, hidden_input_cloned); + + auto less_const_cloned = clone_circleconst(_p->_less_const, graph); + luci::copy_common_attributes(_p->_less_const, less_const_cloned); + + // Create and configure new CircleGRU operation. + auto circle_gru = _p->_while_node->graph()->nodes()->create(); + circle_gru->input(_p->_ifm); + circle_gru->hidden_hidden(weight_hh_cloned); + circle_gru->hidden_input(weight_ih_cloned); + circle_gru->hidden_hidden_bias(bias_hh_cloned); + circle_gru->hidden_input_bias(bias_ih_cloned); + circle_gru->state(hidden_input_cloned); + + // Note: Now support only returnSequences = false + circle_gru->returnSequences(false); + circle_gru->name("FusedCircleGRU"); + + return circle_gru; +} + +void FuseGRU::apply() +{ + auto graph = _p->_pattern_last_node->graph(); + + auto gru_out = create_circle_gru(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p->_while_node), luci::get_origin(_p->_while_out_node), + luci::get_origin(_p->_weight_hh), luci::get_origin(_p->_weight_ih)}; + + luci::add_origin(gru_out, luci::composite_origin(origin_vec)); + + replace(_p->_pattern_last_node).with(gru_out); +} + +} // namespace + +namespace +{ + +bool fuse_gru(luci::CircleWhileOut *while_out_node) +{ + assert(while_out_node); + + // check first pattern + GRUPattern1 pattern(while_out_node); + if (pattern.matched()) + { + FuseGRU fuse(&pattern); + fuse.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseGRUPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto while_out_node = dynamic_cast(node); + if (not while_out_node) + continue; + + if (fuse_gru(while_out_node)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp new file mode 100644 index 00000000000..93909ea673f --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class GRUGraphlet +{ +public: + GRUGraphlet() = default; + + void init(loco::Graph *g) + { + _while_node = g->nodes()->create(5, 5); + _while_out_node = g->nodes()->create(); + _hidden_node = g->nodes()->create(); + _hidden_node->dtype(loco::DataType::FLOAT32); + _time_node = g->nodes()->create(); + _time_node->dtype(loco::DataType::FLOAT32); + _state_node = g->nodes()->create(); + _state_node->dtype(loco::DataType::FLOAT32); + + _body_graph = loco::make_graph(); + _cond_graph = loco::make_graph(); + + _less_node = _cond_graph->nodes()->create(); + _less_const_node = _cond_graph->nodes()->create(); + _less_const_node->dtype(loco::DataType::S32); + _less_const_node->size(1); + _less_const_node->at(0) = 1; + + _add_node_1 = _body_graph->nodes()->create(); + _add_node_2 = _body_graph->nodes()->create(); + _add_node_3 = _body_graph->nodes()->create(); + _add_node_4 = _body_graph->nodes()->create(); + _add_node_5 = _body_graph->nodes()->create(); + _add_node_6 = _body_graph->nodes()->create(); + + _fc_node_1 = _body_graph->nodes()->create(); + _fc_node_2 = _body_graph->nodes()->create(); + _fc_weight_1 = _body_graph->nodes()->create(); + _fc_weight_1->dtype(loco::DataType::FLOAT32); + _fc_weight_2 = _body_graph->nodes()->create(); + _fc_weight_2->dtype(loco::DataType::FLOAT32); + _fc_bias_1 = _body_graph->nodes()->create(); + _fc_bias_1->dtype(loco::DataType::FLOAT32); + _fc_bias_2 = _body_graph->nodes()->create(); + _fc_bias_2->dtype(loco::DataType::FLOAT32); + + _split_const = _body_graph->nodes()->create(); + _split_const->dtype(loco::DataType::S32); + + _logistic_node_1 = _body_graph->nodes()->create(); + _logistic_node_2 = _body_graph->nodes()->create(); + + _gather_node = _body_graph->nodes()->create(); + + _mul_node_1 = _body_graph->nodes()->create(); + _mul_node_2 = _body_graph->nodes()->create(); + _mul_node_3 = _body_graph->nodes()->create(); + + _tanh_node = _body_graph->nodes()->create(); + _sub_node = _body_graph->nodes()->create(); + + _split_node_1 = _body_graph->nodes()->create(); + _split_node_2 = _body_graph->nodes()->create(); + _split_out_node_1 = _body_graph->nodes()->create(); + _split_out_node_2 = _body_graph->nodes()->create(); + _split_out_node_3 = _body_graph->nodes()->create(); + _split_out_node_4 = _body_graph->nodes()->create(); + _split_out_node_5 = _body_graph->nodes()->create(); + _split_out_node_6 = _body_graph->nodes()->create(); + + _reshape_node = _body_graph->nodes()->create(); + + auto graph_input_cond_graph = _cond_graph->inputs()->create(); + _cond_input_node = _cond_graph->nodes()->create(); + _cond_input_node->index(graph_input_cond_graph->index()); + + auto graph_output_cond_graph = _cond_graph->outputs()->create(); + _cond_output_node = _cond_graph->nodes()->create(); + _cond_output_node->index(graph_output_cond_graph->index()); + + auto graph_input_body_graph_1 = _body_graph->inputs()->create(); + _body_input_node_1 = _body_graph->nodes()->create(); + _body_input_node_1->index(graph_input_body_graph_1->index()); + + auto graph_input_body_graph_2 = _body_graph->inputs()->create(); + _body_input_node_2 = _body_graph->nodes()->create(); + _body_input_node_2->index(graph_input_body_graph_2->index()); + + auto graph_input_body_graph_3 = _body_graph->inputs()->create(); + _body_input_node_3 = _body_graph->nodes()->create(); + _body_input_node_3->index(graph_input_body_graph_3->index()); + + auto graph_input_body_graph_4 = _body_graph->inputs()->create(); + _body_input_node_4 = _body_graph->nodes()->create(); + _body_input_node_4->index(graph_input_body_graph_4->index()); + + auto graph_input_body_graph_5 = _body_graph->inputs()->create(); + _body_input_node_5 = _body_graph->nodes()->create(); + _body_input_node_5->index(graph_input_body_graph_5->index()); + + auto graph_output_body_graph_1 = _body_graph->outputs()->create(); + _body_output_node_1 = _body_graph->nodes()->create(); + _body_output_node_1->index(graph_output_body_graph_1->index()); + + auto graph_output_body_graph_2 = _body_graph->outputs()->create(); + _body_output_node_2 = _body_graph->nodes()->create(); + _body_output_node_2->index(graph_output_body_graph_2->index()); + + auto graph_output_body_graph_3 = _body_graph->outputs()->create(); + _body_output_node_3 = _body_graph->nodes()->create(); + _body_output_node_3->index(graph_output_body_graph_3->index()); + + auto graph_output_body_graph_4 = _body_graph->outputs()->create(); + _body_output_node_4 = _body_graph->nodes()->create(); + _body_output_node_4->index(graph_output_body_graph_4->index()); + + auto graph_output_body_graph_5 = _body_graph->outputs()->create(); + _body_output_node_5 = _body_graph->nodes()->create(); + _body_output_node_5->index(graph_output_body_graph_5->index()); + } + + void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } + +protected: + luci::CircleWhile *_while_node; + luci::CircleWhileOut *_while_out_node; + luci::CircleConst *_time_node; + luci::CircleConst *_state_node; + luci::CircleConst *_hidden_node; + + luci::CircleInput *_cond_input_node; + luci::CircleLess *_less_node; + luci::CircleConst *_less_const_node; + luci::CircleOutput *_cond_output_node; + + luci::CircleInput *_body_input_node_1; + luci::CircleInput *_body_input_node_2; + luci::CircleInput *_body_input_node_3; + luci::CircleInput *_body_input_node_4; + luci::CircleInput *_body_input_node_5; + + luci::CircleOutput *_body_output_node_1; + luci::CircleOutput *_body_output_node_2; + luci::CircleOutput *_body_output_node_3; + luci::CircleOutput *_body_output_node_4; + luci::CircleOutput *_body_output_node_5; + + luci::CircleAdd *_add_node_1; + luci::CircleAdd *_add_node_2; + luci::CircleAdd *_add_node_3; + luci::CircleAdd *_add_node_4; + luci::CircleAdd *_add_node_5; + luci::CircleAdd *_add_node_6; + + luci::CircleMul *_mul_node_1; + luci::CircleMul *_mul_node_2; + luci::CircleMul *_mul_node_3; + + luci::CircleSub *_sub_node; + luci::CircleTanh *_tanh_node; + luci::CircleReshape *_reshape_node; + luci::CircleGather *_gather_node; + luci::CircleLogistic *_logistic_node_1; + luci::CircleLogistic *_logistic_node_2; + luci::CircleSplit *_split_node_1; + luci::CircleSplit *_split_node_2; + + luci::CircleSplitOut *_split_out_node_1; + luci::CircleSplitOut *_split_out_node_2; + luci::CircleSplitOut *_split_out_node_3; + luci::CircleSplitOut *_split_out_node_4; + luci::CircleSplitOut *_split_out_node_5; + luci::CircleSplitOut *_split_out_node_6; + + luci::CircleFullyConnected *_fc_node_1; + luci::CircleFullyConnected *_fc_node_2; + + luci::CircleConst *_split_const; + luci::CircleConst *_fc_weight_1; + luci::CircleConst *_fc_bias_1; + luci::CircleConst *_fc_weight_2; + luci::CircleConst *_fc_bias_2; + + std::unique_ptr _cond_graph; + std::unique_ptr _body_graph; +}; + +class FuseGRUTestGraph1 : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestGraph1() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_1->y(_split_const); + _add_node_2->x(_body_input_node_2); + _add_node_2->y(_split_const); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _gather_node->indices(_body_input_node_1); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_1->split_dim(_split_const); + _split_node_2->input(_fc_node_2); + _split_node_2->split_dim(_split_const); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + _sub_node->x(_split_const); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +class FuseGRUTestNegGraph : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + invalid_less_const_type(); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_2->x(_body_input_node_2); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_2->input(_fc_node_2); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +} // namespace + +TEST(FuseGRUPassTest, fuse_pattern1) +{ + FuseGRUTestGraph1 g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGRUPassTest, fuse_NEG) +{ + FuseGRUTestNegGraph g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} From ec29daa017c3ce7fcb041775efb7c8031a187da2 Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Mon, 28 Oct 2024 14:30:13 +0900 Subject: [PATCH 2/8] initialization on test class --- compiler/luci/pass/src/FuseGRUPass.test.cpp | 116 ++++++++++---------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp index 93909ea673f..bb9df366606 100644 --- a/compiler/luci/pass/src/FuseGRUPass.test.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -148,64 +148,64 @@ class GRUGraphlet void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } protected: - luci::CircleWhile *_while_node; - luci::CircleWhileOut *_while_out_node; - luci::CircleConst *_time_node; - luci::CircleConst *_state_node; - luci::CircleConst *_hidden_node; - - luci::CircleInput *_cond_input_node; - luci::CircleLess *_less_node; - luci::CircleConst *_less_const_node; - luci::CircleOutput *_cond_output_node; - - luci::CircleInput *_body_input_node_1; - luci::CircleInput *_body_input_node_2; - luci::CircleInput *_body_input_node_3; - luci::CircleInput *_body_input_node_4; - luci::CircleInput *_body_input_node_5; - - luci::CircleOutput *_body_output_node_1; - luci::CircleOutput *_body_output_node_2; - luci::CircleOutput *_body_output_node_3; - luci::CircleOutput *_body_output_node_4; - luci::CircleOutput *_body_output_node_5; - - luci::CircleAdd *_add_node_1; - luci::CircleAdd *_add_node_2; - luci::CircleAdd *_add_node_3; - luci::CircleAdd *_add_node_4; - luci::CircleAdd *_add_node_5; - luci::CircleAdd *_add_node_6; - - luci::CircleMul *_mul_node_1; - luci::CircleMul *_mul_node_2; - luci::CircleMul *_mul_node_3; - - luci::CircleSub *_sub_node; - luci::CircleTanh *_tanh_node; - luci::CircleReshape *_reshape_node; - luci::CircleGather *_gather_node; - luci::CircleLogistic *_logistic_node_1; - luci::CircleLogistic *_logistic_node_2; - luci::CircleSplit *_split_node_1; - luci::CircleSplit *_split_node_2; - - luci::CircleSplitOut *_split_out_node_1; - luci::CircleSplitOut *_split_out_node_2; - luci::CircleSplitOut *_split_out_node_3; - luci::CircleSplitOut *_split_out_node_4; - luci::CircleSplitOut *_split_out_node_5; - luci::CircleSplitOut *_split_out_node_6; - - luci::CircleFullyConnected *_fc_node_1; - luci::CircleFullyConnected *_fc_node_2; - - luci::CircleConst *_split_const; - luci::CircleConst *_fc_weight_1; - luci::CircleConst *_fc_bias_1; - luci::CircleConst *_fc_weight_2; - luci::CircleConst *_fc_bias_2; + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleConst *_time_node = nullptr; + luci::CircleConst *_state_node = nullptr; + luci::CircleConst *_hidden_node = nullptr; + + luci::CircleInput *_cond_input_node = nullptr; + luci::CircleLess *_less_node = nullptr; + luci::CircleConst *_less_const_node = nullptr; + luci::CircleOutput *_cond_output_node = nullptr; + + luci::CircleInput *_body_input_node_1 = nullptr; + luci::CircleInput *_body_input_node_2 = nullptr; + luci::CircleInput *_body_input_node_3 = nullptr; + luci::CircleInput *_body_input_node_4 = nullptr; + luci::CircleInput *_body_input_node_5 = nullptr; + + luci::CircleOutput *_body_output_node_1 = nullptr; + luci::CircleOutput *_body_output_node_2 = nullptr; + luci::CircleOutput *_body_output_node_3 = nullptr; + luci::CircleOutput *_body_output_node_4 = nullptr; + luci::CircleOutput *_body_output_node_5 = nullptr; + + luci::CircleAdd *_add_node_1 = nullptr; + luci::CircleAdd *_add_node_2 = nullptr; + luci::CircleAdd *_add_node_3 = nullptr; + luci::CircleAdd *_add_node_4 = nullptr; + luci::CircleAdd *_add_node_5 = nullptr; + luci::CircleAdd *_add_node_6 = nullptr; + + luci::CircleMul *_mul_node_1 = nullptr; + luci::CircleMul *_mul_node_2 = nullptr; + luci::CircleMul *_mul_node_3 = nullptr; + + luci::CircleSub *_sub_node = nullptr; + luci::CircleTanh *_tanh_node = nullptr; + luci::CircleReshape *_reshape_node = nullptr; + luci::CircleGather *_gather_node = nullptr; + luci::CircleLogistic *_logistic_node_1 = nullptr; + luci::CircleLogistic *_logistic_node_2 = nullptr; + luci::CircleSplit *_split_node_1 = nullptr; + luci::CircleSplit *_split_node_2 = nullptr; + + luci::CircleSplitOut *_split_out_node_1 = nullptr; + luci::CircleSplitOut *_split_out_node_2 = nullptr; + luci::CircleSplitOut *_split_out_node_3 = nullptr; + luci::CircleSplitOut *_split_out_node_4 = nullptr; + luci::CircleSplitOut *_split_out_node_5 = nullptr; + luci::CircleSplitOut *_split_out_node_6 = nullptr; + + luci::CircleFullyConnected *_fc_node_1 = nullptr; + luci::CircleFullyConnected *_fc_node_2 = nullptr; + + luci::CircleConst *_split_const = nullptr; + luci::CircleConst *_fc_weight_1 = nullptr; + luci::CircleConst *_fc_bias_1 = nullptr; + luci::CircleConst *_fc_weight_2 = nullptr; + luci::CircleConst *_fc_bias_2 = nullptr; std::unique_ptr _cond_graph; std::unique_ptr _body_graph; From de1f543d4ec7b401641392b0b1c342442f2b7578 Mon Sep 17 00:00:00 2001 From: chunseoklee Date: Mon, 28 Oct 2024 14:34:03 +0900 Subject: [PATCH 3/8] Update compiler/luci/pass/src/FuseGRUPass.cpp Co-authored-by: SaeHie Park --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index 2f1f2d341ef..ab98afdee3c 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -104,7 +104,7 @@ class GRUPatternBase * | [Split_1] [FullyConnected_2] * | / | \ | * | | | \ [Split_2] - * | [Add_1] -------+----+---------------------------------/ | | + * | [Add_1] ----------------------------------------------/ | | * | | | | | | * | | | ------------------------------------[Add_4] | * | | | | | From 604fb70d4a18b86256fb750a5147bab1cf244d5b Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Mon, 28 Oct 2024 19:16:06 +0900 Subject: [PATCH 4/8] throw instead of assert --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index ab98afdee3c..c6f0d58c8fd 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -551,7 +551,7 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph break; default: - assert(false); + throw std::runtime_error("Unsupported data type"); } } From fada0d26fea112ac190cc4dc6f7e1f1367d9cef0 Mon Sep 17 00:00:00 2001 From: chunseoklee Date: Tue, 29 Oct 2024 09:48:04 +0900 Subject: [PATCH 5/8] Update compiler/luci/pass/src/FuseGRUPass.cpp Co-authored-by: SaeHie Park --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index c6f0d58c8fd..53421804540 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -551,7 +551,7 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph break; default: - throw std::runtime_error("Unsupported data type"); + throw std::runtime_error("FuseGRU: Unsupported data type"); } } From 338a8780d763f42bac97789b509d608b732216d9 Mon Sep 17 00:00:00 2001 From: chunseoklee Date: Tue, 29 Oct 2024 13:03:20 +0900 Subject: [PATCH 6/8] Apply suggestions(must_cast) Co-authored-by: Hyukjin Jeong --- compiler/luci/pass/src/FuseGRUPass.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index 53421804540..a66f9aa9950 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -223,7 +223,7 @@ bool GRUPattern1::matched() for (auto node : body_nodes) { - auto circle_node = dynamic_cast(node); + auto circle_node = loco::must_cast(node); switch (circle_node->opcode()) { case luci::CircleOpcode::CIRCLECONST: @@ -232,34 +232,34 @@ bool GRUPattern1::matched() case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE: break; case luci::CircleOpcode::FULLY_CONNECTED: - fc_nodes.push_back(dynamic_cast(circle_node)); + fc_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::SPLIT: - split_nodes.push_back(dynamic_cast(circle_node)); + split_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::LOGISTIC: - logistic_nodes.push_back(dynamic_cast(circle_node)); + logistic_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::MUL: - mul_nodes.push_back(dynamic_cast(circle_node)); + mul_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::ADD: - add_nodes.push_back(dynamic_cast(circle_node)); + add_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::SUB: - sub_nodes.push_back(dynamic_cast(circle_node)); + sub_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::RESHAPE: - reshape_nodes.push_back(dynamic_cast(circle_node)); + reshape_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::GATHER: - gather_nodes.push_back(dynamic_cast(circle_node)); + gather_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::TANH: - tanh_nodes.push_back(dynamic_cast(circle_node)); + tanh_nodes.push_back(loco::must_cast(circle_node)); break; case luci::CircleOpcode::CIRCLESPLITOUT: - split_out_nodes.push_back(dynamic_cast(circle_node)); + split_out_nodes.push_back(loco::must_cast(circle_node)); break; default: return false; From c69e2d50b7b0a75f4c34f5ea6523e6f4a84b8fd5 Mon Sep 17 00:00:00 2001 From: chunseoklee Date: Wed, 30 Oct 2024 10:48:00 +0900 Subject: [PATCH 7/8] Update compiler/luci/pass/src/FuseGRUPass.cpp Co-authored-by: Hyukjin Jeong --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index a66f9aa9950..0f95673ac7c 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -607,7 +607,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) // Note: Now support only returnSequences = false circle_gru->returnSequences(false); - circle_gru->name("FusedCircleGRU"); + circle_gru->name(_while_node->name() + "_FusedCircleGRU"); return circle_gru; } From 7ac097fbf876a32e9d4106dfa81f7881ced761a1 Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Mon, 11 Nov 2024 13:32:56 +0900 Subject: [PATCH 8/8] revise fusegru pass --- compiler/luci/pass/src/FuseGRUPass.cpp | 462 ++++++++----------------- 1 file changed, 135 insertions(+), 327 deletions(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index 0f95673ac7c..7164c2a3d1d 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -30,15 +30,17 @@ namespace { -class GRUPatternBase +class GRUPattern final { public: - GRUPatternBase(luci::CircleNode *candidate) { _pattern_last_node = candidate; } - - virtual ~GRUPatternBase() = default; + GRUPattern(luci::CircleWhileOut *candidate) + { + assert(candidate); + _while_out_node = candidate; + } + ~GRUPattern() = default; -public: - virtual bool matched() = 0; + bool matched(); public: luci::CircleNode *_ifm = nullptr; @@ -53,7 +55,27 @@ class GRUPatternBase luci::CircleWhile *_while_node = nullptr; luci::CircleWhileOut *_while_out_node = nullptr; - luci::CircleNode *_pattern_last_node = nullptr; + + luci::CircleReshape *reshape = nullptr; + luci::CircleConst *reshape_shape = nullptr; + + luci::CircleAdd *add_6 = nullptr; + luci::CircleMul *mul_1 = nullptr; + luci::CircleMul *mul_3 = nullptr; + luci::CircleSub *sub_with_const = nullptr; + luci::CircleTanh *tanh = nullptr; + luci::CircleLogistic *logistic_2 = nullptr; + luci::CircleAdd *add_5 = nullptr; + luci::CircleMul *mul_2 = nullptr; + luci::CircleAdd *add_1 = nullptr; + luci::CircleSplitOut *split_1_out = nullptr; + luci::CircleSplitOut *split_2_out = nullptr; + luci::CircleSplit *split_1 = nullptr; + luci::CircleSplit *split_2 = nullptr; + luci::CircleLogistic *logistic_1 = nullptr; + luci::CircleAdd *add_4 = nullptr; + luci::CircleFullyConnected *fc_1 = nullptr; + luci::CircleFullyConnected *fc_2 = nullptr; }; /** @@ -128,34 +150,24 @@ class GRUPatternBase * | * [Out_1] */ -class GRUPattern1 final : public GRUPatternBase -{ -public: - GRUPattern1(luci::CircleWhileOut *candidate) : GRUPatternBase(candidate) - { - assert(candidate); - _while_out_node = candidate; - } -public: - bool matched() override; -}; +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; -bool GRUPattern1::matched() +bool GRUPattern::matched() { // 0 - check while node - _while_node = dynamic_cast(_while_out_node->input()); - if (_while_node == nullptr) - return false; + _while_node = loco::must_cast(_while_out_node->input()); + CHECK_OR_FALSE(_while_node != nullptr); - // 1 - check condition graph: only one Less operation - // with scalar int const value + // 1 - check condition graph { const auto cond_graph = _while_node->cond_graph(); const auto cond_nodes = loco::active_nodes(loco::output_nodes(cond_graph)); - if (cond_nodes.size() != 4) - return false; + CHECK_OR_FALSE(cond_nodes.size() == 4); + luci::CircleLess *less_node = nullptr; for (auto node : cond_nodes) { @@ -163,321 +175,117 @@ bool GRUPattern1::matched() if (less_node != nullptr) break; } + CHECK_OR_FALSE(less_node != nullptr); - // doesn't find Less node - if (less_node == nullptr) - return false; - - luci::CircleNode *less_input; - if (not luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)) - return false; - - if (_less_const->dtype() != loco::DataType::S32) - return false; - - if (_less_const->size() != 1) - return false; - - assert(_less_const->at(0) > 0); + luci::CircleNode *less_input = nullptr; + CHECK_OR_FALSE(luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)); + CHECK_OR_FALSE(_less_const->dtype() == loco::DataType::S32); + CHECK_OR_FALSE(_less_const->size() == 1); + CHECK_OR_FALSE(_less_const->at(0) > 0); } // 2 - Check while's input nodes // Save hidden state input node { - if (_while_node->input_count() != 5) - return false; + CHECK_OR_FALSE(_while_node->input_count() == 5); // Save input node - _ifm = dynamic_cast(_while_node->input(4)); - if (_ifm == nullptr) - return false; - - _hidden_input = dynamic_cast(_while_node->input(3)); - if (_hidden_input == nullptr) - return false; + _ifm = loco::must_cast(_while_node->input(4)); + _hidden_input = loco::must_cast(_while_node->input(3)); } // 3 - check body graph { const auto body_graph = _while_node->body_graph(); - if (loco::input_nodes(body_graph).size() != 5) - return false; + CHECK_OR_FALSE(loco::input_nodes(body_graph).size() == 5); + CHECK_OR_FALSE(loco::output_nodes(body_graph).size() == 5); - if (loco::output_nodes(body_graph).size() != 5) - return false; + /* Let's check the bottom part of the body graph + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ const auto body_nodes = loco::active_nodes(loco::output_nodes(body_graph)); - // Save all nodes according its types - std::vector fc_nodes; - std::vector split_nodes; - std::vector logistic_nodes; - std::vector mul_nodes; - std::vector add_nodes; - std::vector sub_nodes; - std::vector reshape_nodes; - std::vector gather_nodes; - std::vector tanh_nodes; - std::vector split_out_nodes; - - for (auto node : body_nodes) - { - auto circle_node = loco::must_cast(node); - switch (circle_node->opcode()) - { - case luci::CircleOpcode::CIRCLECONST: - case luci::CircleOpcode::CIRCLEINPUT: - case luci::CircleOpcode::CIRCLEOUTPUT: - case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE: - break; - case luci::CircleOpcode::FULLY_CONNECTED: - fc_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::SPLIT: - split_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::LOGISTIC: - logistic_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::MUL: - mul_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::ADD: - add_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::SUB: - sub_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::RESHAPE: - reshape_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::GATHER: - gather_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::TANH: - tanh_nodes.push_back(loco::must_cast(circle_node)); - break; - case luci::CircleOpcode::CIRCLESPLITOUT: - split_out_nodes.push_back(loco::must_cast(circle_node)); - break; - default: - return false; - } - } - - // Check number of nodes - if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or - split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or - reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or - split_out_nodes.size() != 6) - return false; - - // Check structure - // TODO: add more checks - { - // 1 - Check Split ops - // Both has FC nodes as input - // Axis is const - for (auto node : split_nodes) - { - if (dynamic_cast(node->split_dim()) == nullptr or - dynamic_cast(node->input()) == nullptr) - return false; - } - - // 2 - Check Logistic ops - // Add is input node for both nodes - for (auto node : logistic_nodes) - { - if (dynamic_cast(node->x()) == nullptr) - return false; - } - - // 3 - Check Sub - // Const - is first input node - // Logistic - is second input node - for (auto node : sub_nodes) - { - if (dynamic_cast(node->y()) == nullptr or - dynamic_cast(node->x()) == nullptr) - return false; - } - - // 4 - Check Add - // Mul or Const or Input or Split ops can be input nodes - // Mul - 3 times as input - // Const - 2 times as input - // Input - 2 times as input - // Split - 5 times as input - { - int num_mul = 0; - int num_const = 0; - int num_input = 0; - int num_split = 0; - for (auto node : add_nodes) - { - auto x_node = dynamic_cast(node->x()); - auto y_node = dynamic_cast(node->y()); - switch (x_node->opcode()) - { - case luci::CircleOpcode::CIRCLECONST: - num_const++; - break; - case luci::CircleOpcode::CIRCLEINPUT: - num_input++; - break; - case luci::CircleOpcode::CIRCLESPLITOUT: - num_split++; - break; - case luci::CircleOpcode::MUL: - num_mul++; - break; - default: - return false; - } - - switch (y_node->opcode()) - { - case luci::CircleOpcode::CIRCLECONST: - num_const++; - break; - case luci::CircleOpcode::CIRCLEINPUT: - num_input++; - break; - case luci::CircleOpcode::CIRCLESPLITOUT: - num_split++; - break; - case luci::CircleOpcode::MUL: - num_mul++; - break; - default: - return false; - } - } - if (num_mul != 3 or num_split != 5 or num_const != 2 or num_input != 2) - return false; - } - } - - // 5 - Check Mul - // Logistic or Tanh or Sub or Input or Split ops can be input nodes - // Logistic - 2 times as input - // Tanh - 1 times as input - // Sub - 1 times as input - // Split - 1 times as input - // Input - 1 times as input - { - int num_logistic = 0; - int num_tanh = 0; - int num_sub = 0; - int num_split = 0; - int num_input = 0; - for (auto node : mul_nodes) - { - auto x_node = dynamic_cast(node->x()); - auto y_node = dynamic_cast(node->y()); - switch (x_node->opcode()) - { - case luci::CircleOpcode::LOGISTIC: - num_logistic++; - break; - case luci::CircleOpcode::CIRCLEINPUT: - num_input++; - break; - case luci::CircleOpcode::CIRCLESPLITOUT: - num_split++; - break; - case luci::CircleOpcode::TANH: - num_tanh++; - break; - case luci::CircleOpcode::SUB: - num_sub++; - break; - default: - return false; - } - - switch (y_node->opcode()) - { - case luci::CircleOpcode::LOGISTIC: - num_logistic++; - break; - case luci::CircleOpcode::CIRCLEINPUT: - num_input++; - break; - case luci::CircleOpcode::CIRCLESPLITOUT: - num_split++; - break; - case luci::CircleOpcode::TANH: - num_tanh++; - break; - case luci::CircleOpcode::SUB: - num_sub++; - break; - default: - return false; - } - } - if (num_logistic != 2 or num_tanh != 1 or num_sub != 1 or num_split != 1 or num_input != 1) - return false; - } - - // 6 - Check Gather - // Gather has two CircleInput as input - { - for (auto node : gather_nodes) - { - if (dynamic_cast(node->indices()) == nullptr) - return false; - - if (dynamic_cast(node->params()) == nullptr) - return false; - } - } - - // 7 - Check Tanh - // Input is CircleAdd + for (auto node : loco::active_nodes(loco::output_nodes(body_graph))) { - for (auto node : tanh_nodes) - { - if (dynamic_cast(node->x()) == nullptr) - return false; - } + reshape = dynamic_cast(node); + if (reshape) + break; } - - // Find input and hidden FC weights and biases - for (auto node : body_nodes) + CHECK_OR_FALSE(reshape != nullptr); + + add_6 = loco::must_cast(reshape->tensor()); + + /* Let's check the next bottom part above add_6 + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + */ + + CHECK_OR_FALSE(luci::fill(&mul_1, &mul_3).with_args_of(add_6)); + CHECK_OR_FALSE(luci::fill(&sub_with_const, &tanh).with_args_of(mul_3)); + + logistic_2 = loco::must_cast(sub_with_const->y()); + + /* Let's check the next bottom part above logistic_2 + * | | | \ [Split_2] + * | [Add_1] ----------------------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + */ + add_5 = loco::must_cast(tanh->x()); + add_1 = loco::must_cast(logistic_2->x()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &split_2_out).with_commutative_args_of(add_1)); + CHECK_OR_FALSE(luci::fill(&split_2_out, &mul_2).with_commutative_args_of(add_5)); + split_2 = loco::must_cast(split_2_out->input()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &logistic_1).with_commutative_args_of(mul_2)); + split_1 = loco::must_cast(split_1_out->input()); + add_4 = loco::must_cast(logistic_1->x()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &split_2_out).with_args_of(add_4)); + + /* Let's check the remainig top part + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | [Add_1] ----------------------------------------------/ | | + */ + fc_1 = loco::must_cast(split_1->input()); + fc_2 = loco::must_cast(split_2->input()); + { - auto *fc_node = dynamic_cast(node); - if (fc_node == nullptr) - continue; - - const auto input_node = dynamic_cast(fc_node->input()); - if (input_node == nullptr) - return false; - - // For input hidden FullyConnected - input node is CircleInput node - if (dynamic_cast(input_node) != nullptr) - { - _weight_ih = dynamic_cast(fc_node->weights()); - _bias_ih = dynamic_cast(fc_node->bias()); - } - // For hidden hidden FullyConnected - input node is CircleGather node - else if (dynamic_cast(input_node) != nullptr) - { - _weight_hh = dynamic_cast(fc_node->weights()); - _bias_hh = dynamic_cast(fc_node->bias()); - } - else - { + _weight_ih = loco::must_cast(fc_1->weights()); + _bias_ih = dynamic_cast(fc_1->bias()); + _weight_hh = loco::must_cast(fc_2->weights()); + _bias_hh = dynamic_cast(fc_2->bias()); + if (_weight_ih == nullptr or _weight_hh == nullptr) return false; - } } - - if (_weight_ih == nullptr or _weight_hh == nullptr) - return false; - } + } return true; } @@ -485,7 +293,7 @@ bool GRUPattern1::matched() class FuseGRU final { public: - FuseGRU(const GRUPatternBase *p) : _p(p) {} + FuseGRU(const GRUPattern *p) : _p(p) {} public: void apply(void); @@ -494,7 +302,7 @@ class FuseGRU final luci::CircleGRU *create_circle_gru(loco::Graph *graph); private: - const GRUPatternBase *_p; + const GRUPattern *_p; }; template @@ -576,7 +384,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) } else { - bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + bias_ih_cloned = graph->nodes()->create(); } luci::CircleNode *bias_hh_cloned = nullptr; @@ -587,7 +395,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) } else { - bias_hh_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + bias_hh_cloned = graph->nodes()->create(); } auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); @@ -597,7 +405,7 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) luci::copy_common_attributes(_p->_less_const, less_const_cloned); // Create and configure new CircleGRU operation. - auto circle_gru = _p->_while_node->graph()->nodes()->create(); + auto circle_gru = graph->nodes()->create(); circle_gru->input(_p->_ifm); circle_gru->hidden_hidden(weight_hh_cloned); circle_gru->hidden_input(weight_ih_cloned); @@ -607,14 +415,14 @@ luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) // Note: Now support only returnSequences = false circle_gru->returnSequences(false); - circle_gru->name(_while_node->name() + "_FusedCircleGRU"); + circle_gru->name(_p->_while_node->name() + "_FusedCircleGRU"); return circle_gru; } void FuseGRU::apply() { - auto graph = _p->_pattern_last_node->graph(); + auto graph = _p->_while_out_node->graph(); auto gru_out = create_circle_gru(graph); @@ -625,7 +433,7 @@ void FuseGRU::apply() luci::add_origin(gru_out, luci::composite_origin(origin_vec)); - replace(_p->_pattern_last_node).with(gru_out); + replace(_p->_while_out_node).with(gru_out); } } // namespace @@ -638,7 +446,7 @@ bool fuse_gru(luci::CircleWhileOut *while_out_node) assert(while_out_node); // check first pattern - GRUPattern1 pattern(while_out_node); + GRUPattern pattern(while_out_node); if (pattern.matched()) { FuseGRU fuse(&pattern);