diff --git a/include/circt/Transforms/Passes.h b/include/circt/Transforms/Passes.h index dadf74e4c2a6..2a13b48cfb3c 100644 --- a/include/circt/Transforms/Passes.h +++ b/include/circt/Transforms/Passes.h @@ -46,6 +46,7 @@ std::unique_ptr createInsertMergeBlocksPass(); std::unique_ptr createPrintOpCountPass(); std::unique_ptr createMemoryBankingPass(std::optional bankingFactor = std::nullopt); +std::unique_ptr createIndexSwitchToIfPass(); //===----------------------------------------------------------------------===// // Utility functions. diff --git a/include/circt/Transforms/Passes.td b/include/circt/Transforms/Passes.td index 12ef3dcf6dbd..67e25d2db4eb 100644 --- a/include/circt/Transforms/Passes.td +++ b/include/circt/Transforms/Passes.td @@ -133,4 +133,35 @@ def MemoryBanking : Pass<"memory-banking", "::mlir::func::FuncOp"> { let dependentDialects = ["mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::affine::AffineDialect"]; } +def IndexSwitchToIf : Pass<"switch-to-if", "::mlir::ModuleOp"> { + let summary = "Index switch to if"; + let description = [{ + Transform `scf.index_switch` to a series of `scf.if` operations. + This is necessary for dialects that don't support switch statements, e.g., Calyx. + An example: + ``` + %0 = scf.index_switch %cond -> i32 + case 0 { ... } + case 1 { ... } + ... + + => + + %c0 = arith.cmpi eq %0, 0 + %c1 = arith.cmpi eq %0, 1 + %0 = scf.if %c0 { + ... + } else { + %1 = scf.if %c1 { + ... + } else { + ... + } + } + ``` + }]; + let constructor = "circt::createIndexSwitchToIfPass()"; + let dependentDialects = ["mlir::scf::SCFDialect"]; +} + #endif // CIRCT_TRANSFORMS_PASSES diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index 2b406b2c54dc..92b58479abd6 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -1625,6 +1625,70 @@ class BuildParGroups : public calyx::FuncOpPartialLoweringPattern { } }; +class BuildSwitchGroups : public calyx::FuncOpPartialLoweringPattern { + using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern; + + LogicalResult + partiallyLowerFuncToComp(FuncOp funcOp, + PatternRewriter &rewriter) const override { + LogicalResult res = success(); + funcOp.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + + auto switchOp = cast(op); + auto loc = switchOp.getLoc(); + + Region &defaultRegion = switchOp.getDefaultRegion(); + Operation *yieldOp = defaultRegion.front().getTerminator(); + Value defaultResult = yieldOp->getOperand(0); + + Value finalResult = defaultResult; + scf::IfOp prevIfOp = nullptr; + + rewriter.setInsertionPointAfter(switchOp); + for (size_t i = 0; i < switchOp.getCases().size(); i++) { + auto caseValueInt = switchOp.getCases()[i]; + if (prevIfOp) + rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front()); + + Value caseValue = rewriter.create(loc, caseValueInt); + Value cond = rewriter.create( + loc, CmpIPredicate::eq, *switchOp.getODSOperands(0).begin(), + caseValue); + + auto ifOp = rewriter.create(loc, switchOp.getResultTypes(), + cond, /*hasElseRegion=*/true); + + Region &caseRegion = switchOp.getCaseRegions()[i]; + IRMapping mapping; + Block &emptyThenBlock = ifOp.getThenRegion().front(); + emptyThenBlock.erase(); + caseRegion.cloneInto(&ifOp.getThenRegion(), mapping); + + if (i == switchOp.getCases().size() - 1) { + rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + rewriter.create(loc, defaultResult); + } + + if (prevIfOp) { + rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front()); + rewriter.create(loc, ifOp.getResult(0)); + } + + if (i == 0) + finalResult = ifOp.getResult(0); + prevIfOp = ifOp; + } + + rewriter.replaceOp(switchOp, finalResult); + + return WalkResult::advance(); + }); + return res; + } +}; + /// Builds a control schedule by traversing the CFG of the function and /// associating this with the previously created groups. /// For simplicity, the generated control flow is expanded for all possible @@ -2401,6 +2465,10 @@ void SCFToCalyxPass::runOnOperation() { addOncePattern(loweringPatterns, patternState, funcMap, *loweringState); + /// This pattern converts all scf.IndexSwitchOps to nested if-elses. + addOncePattern(loweringPatterns, patternState, funcMap, + *loweringState); + /// This pattern converts all index typed values to an i32 integer. addOncePattern(loweringPatterns, patternState, funcMap, *loweringState); diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 66c316e32a37..d7788dfbb61d 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_circt_library(CIRCTTransforms InsertMergeBlocks.cpp PrintOpCount.cpp MemoryBanking.cpp + IndexSwitchToIf.cpp ADDITIONAL_HEADER_DIRS ${CIRCT_MAIN_INCLUDE_DIR}/circt/Transforms @@ -19,6 +20,7 @@ add_circt_library(CIRCTTransforms MLIRFuncDialect MLIRIR MLIRMemRefDialect + MLIRSCFDialect MLIRSupport MLIRTransforms MLIRAffineDialect diff --git a/lib/Transforms/IndexSwitchToIf.cpp b/lib/Transforms/IndexSwitchToIf.cpp new file mode 100644 index 000000000000..8c6ffc49af8c --- /dev/null +++ b/lib/Transforms/IndexSwitchToIf.cpp @@ -0,0 +1,118 @@ +//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definitions of the SCF IndexSwitch to If-Else pass. +// +//===----------------------------------------------------------------------===// + +#include "circt/Transforms/Passes.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace circt { +#define GEN_PASS_DEF_INDEXSWITCHTOIF +#include "circt/Transforms/Passes.h.inc" +} // namespace circt + +using namespace mlir; +using namespace circt; + +struct SwitchToIfConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = switchOp.getLoc(); + + Region &defaultRegion = switchOp.getDefaultRegion(); + + Value finalResult; + scf::IfOp prevIfOp = nullptr; + + rewriter.setInsertionPointAfter(switchOp); + auto switchCases = switchOp.getCases(); + for (size_t i = 0; i < switchCases.size(); i++) { + auto caseValueInt = switchCases[i]; + if (prevIfOp) + rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front()); + + Value caseValue = + rewriter.create(loc, caseValueInt); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, switchOp.getOperand(), caseValue); + + auto ifOp = rewriter.create(loc, switchOp.getResultTypes(), + cond, /*hasElseRegion=*/true); + + Region &caseRegion = switchOp.getCaseRegions()[i]; + IRMapping mapping; + Block &emptyThenBlock = ifOp.getThenRegion().front(); + emptyThenBlock.erase(); + caseRegion.cloneInto(&ifOp.getThenRegion(), mapping); + + if (i + 1 == switchCases.size()) { + rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + Block &elseBlock = ifOp.getElseRegion().front(); + elseBlock.erase(); + defaultRegion.cloneInto(&ifOp.getElseRegion(), mapping); + } + + if (prevIfOp) { + rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front()); + rewriter.create(loc, ifOp.getResult(0)); + } + + if (i == 0) + finalResult = ifOp.getResult(0); + prevIfOp = ifOp; + } + + rewriter.replaceOp(switchOp, finalResult); + + return success(); + } +}; + +namespace { + +struct IndexSwitchToIfPass + : public circt::impl::IndexSwitchToIfBase { +public: + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + ConversionTarget target(*ctx); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + + patterns.add(ctx); + + if (applyPartialConversion(getOperation(), target, std::move(patterns)) + .failed()) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace circt { +std::unique_ptr createIndexSwitchToIfPass() { + return std::make_unique(); +} +} // namespace circt diff --git a/test/Transforms/switch-to-if.mlir b/test/Transforms/switch-to-if.mlir new file mode 100644 index 000000000000..e0a24af752a9 --- /dev/null +++ b/test/Transforms/switch-to-if.mlir @@ -0,0 +1,57 @@ +// RUN: circt-opt -split-input-file --switch-to-if %s | FileCheck %s + +// CHECK-LABEL: func.func @example( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> i32 { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_0]] : index +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_4:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_3]] : index +// CHECK: %[[VAL_5:.*]] = scf.if %[[VAL_4]] -> (i32) { +// CHECK: %[[VAL_6:.*]] = arith.constant 10 : i32 +// CHECK: scf.yield %[[VAL_6]] : i32 +// CHECK: } else { +// CHECK: %[[VAL_7:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_7]] : index +// CHECK: %[[VAL_9:.*]] = scf.if %[[VAL_8]] -> (i32) { +// CHECK: %[[VAL_10:.*]] = arith.constant 20 : i32 +// CHECK: scf.yield %[[VAL_10]] : i32 +// CHECK: } else { +// CHECK: %[[VAL_11:.*]] = arith.constant 7 : index +// CHECK: %[[VAL_12:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_11]] : index +// CHECK: %[[VAL_13:.*]] = scf.if %[[VAL_12]] -> (i32) { +// CHECK: %[[VAL_14:.*]] = arith.constant 30 : i32 +// CHECK: scf.yield %[[VAL_14]] : i32 +// CHECK: } else { +// CHECK: %[[VAL_15:.*]] = arith.constant 50 : i32 +// CHECK: scf.yield %[[VAL_15]] : i32 +// CHECK: } +// CHECK: scf.yield %[[VAL_13]] : i32 +// CHECK: } +// CHECK: scf.yield %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return %[[VAL_5]] : i32 +// CHECK: } +module { + func.func @example(%arg0 : index) -> i32 { + %one = arith.constant 1 : index + %cond = arith.addi %one, %arg0 : index + %0 = scf.index_switch %cond -> i32 + case 2 { + %1 = arith.constant 10 : i32 + scf.yield %1 : i32 + } + case 5 { + %2 = arith.constant 20 : i32 + scf.yield %2 : i32 + } + case 7 { + %3 = arith.constant 30 : i32 + scf.yield %3 : i32 + } + default { + %4 = arith.constant 50 : i32 + scf.yield %4 : i32 + } + return %0 : i32 + } +}