Skip to content

Commit

Permalink
Transform SCF IndexSwitch to nested If-Else (#7670)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 authored Nov 25, 2024
1 parent ef72460 commit d379a35
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/circt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ std::unique_ptr<mlir::Pass> createInsertMergeBlocksPass();
std::unique_ptr<mlir::Pass> createPrintOpCountPass();
std::unique_ptr<mlir::Pass>
createMemoryBankingPass(std::optional<unsigned> bankingFactor = std::nullopt);
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass();

//===----------------------------------------------------------------------===//
// Utility functions.
Expand Down
31 changes: 31 additions & 0 deletions include/circt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,35 @@ def MemoryBanking : Pass<"memory-banking", "::mlir::func::FuncOp"> {
let dependentDialects = ["mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::affine::AffineDialect"];
}

def IndexSwitchToIf : Pass<"switch-to-if", "::mlir::ModuleOp"> {
let summary = "Index switch to if";
let description = [{
Transform `scf.index_switch` to a series of `scf.if` operations.
This is necessary for dialects that don't support switch statements, e.g., Calyx.
An example:
```
%0 = scf.index_switch %cond -> i32
case 0 { ... }
case 1 { ... }
...

=>

%c0 = arith.cmpi eq %0, 0
%c1 = arith.cmpi eq %0, 1
%0 = scf.if %c0 {
...
} else {
%1 = scf.if %c1 {
...
} else {
...
}
}
```
}];
let constructor = "circt::createIndexSwitchToIfPass()";
let dependentDialects = ["mlir::scf::SCFDialect"];
}

#endif // CIRCT_TRANSFORMS_PASSES
68 changes: 68 additions & 0 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,70 @@ class BuildParGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildSwitchGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
LogicalResult res = success();
funcOp.walk([&](Operation *op) {
if (!isa<scf::IndexSwitchOp>(op))
return WalkResult::advance();

auto switchOp = cast<scf::IndexSwitchOp>(op);
auto loc = switchOp.getLoc();

Region &defaultRegion = switchOp.getDefaultRegion();
Operation *yieldOp = defaultRegion.front().getTerminator();
Value defaultResult = yieldOp->getOperand(0);

Value finalResult = defaultResult;
scf::IfOp prevIfOp = nullptr;

rewriter.setInsertionPointAfter(switchOp);
for (size_t i = 0; i < switchOp.getCases().size(); i++) {
auto caseValueInt = switchOp.getCases()[i];
if (prevIfOp)
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());

Value caseValue = rewriter.create<ConstantIndexOp>(loc, caseValueInt);
Value cond = rewriter.create<CmpIOp>(
loc, CmpIPredicate::eq, *switchOp.getODSOperands(0).begin(),
caseValue);

auto ifOp = rewriter.create<scf::IfOp>(loc, switchOp.getResultTypes(),
cond, /*hasElseRegion=*/true);

Region &caseRegion = switchOp.getCaseRegions()[i];
IRMapping mapping;
Block &emptyThenBlock = ifOp.getThenRegion().front();
emptyThenBlock.erase();
caseRegion.cloneInto(&ifOp.getThenRegion(), mapping);

if (i == switchOp.getCases().size() - 1) {
rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, defaultResult);
}

if (prevIfOp) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
}

if (i == 0)
finalResult = ifOp.getResult(0);
prevIfOp = ifOp;
}

rewriter.replaceOp(switchOp, finalResult);

return WalkResult::advance();
});
return res;
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
Expand Down Expand Up @@ -2401,6 +2465,10 @@ void SCFToCalyxPass::runOnOperation() {
addOncePattern<BuildParGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts all scf.IndexSwitchOps to nested if-elses.
addOncePattern<BuildSwitchGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

/// This pattern converts all index typed values to an i32 integer.
addOncePattern<calyx::ConvertIndexTypes>(loweringPatterns, patternState,
funcMap, *loweringState);
Expand Down
2 changes: 2 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_circt_library(CIRCTTransforms
InsertMergeBlocks.cpp
PrintOpCount.cpp
MemoryBanking.cpp
IndexSwitchToIf.cpp

ADDITIONAL_HEADER_DIRS
${CIRCT_MAIN_INCLUDE_DIR}/circt/Transforms
Expand All @@ -19,6 +20,7 @@ add_circt_library(CIRCTTransforms
MLIRFuncDialect
MLIRIR
MLIRMemRefDialect
MLIRSCFDialect
MLIRSupport
MLIRTransforms
MLIRAffineDialect
Expand Down
118 changes: 118 additions & 0 deletions lib/Transforms/IndexSwitchToIf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the SCF IndexSwitch to If-Else pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace circt {
#define GEN_PASS_DEF_INDEXSWITCHTOIF
#include "circt/Transforms/Passes.h.inc"
} // namespace circt

using namespace mlir;
using namespace circt;

struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = switchOp.getLoc();

Region &defaultRegion = switchOp.getDefaultRegion();

Value finalResult;
scf::IfOp prevIfOp = nullptr;

rewriter.setInsertionPointAfter(switchOp);
auto switchCases = switchOp.getCases();
for (size_t i = 0; i < switchCases.size(); i++) {
auto caseValueInt = switchCases[i];
if (prevIfOp)
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());

Value caseValue =
rewriter.create<arith::ConstantIndexOp>(loc, caseValueInt);
Value cond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, switchOp.getOperand(), caseValue);

auto ifOp = rewriter.create<scf::IfOp>(loc, switchOp.getResultTypes(),
cond, /*hasElseRegion=*/true);

Region &caseRegion = switchOp.getCaseRegions()[i];
IRMapping mapping;
Block &emptyThenBlock = ifOp.getThenRegion().front();
emptyThenBlock.erase();
caseRegion.cloneInto(&ifOp.getThenRegion(), mapping);

if (i + 1 == switchCases.size()) {
rewriter.setInsertionPointToEnd(&ifOp.getElseRegion().front());
Block &elseBlock = ifOp.getElseRegion().front();
elseBlock.erase();
defaultRegion.cloneInto(&ifOp.getElseRegion(), mapping);
}

if (prevIfOp) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
}

if (i == 0)
finalResult = ifOp.getResult(0);
prevIfOp = ifOp;
}

rewriter.replaceOp(switchOp, finalResult);

return success();
}
};

namespace {

struct IndexSwitchToIfPass
: public circt::impl::IndexSwitchToIfBase<IndexSwitchToIfPass> {
public:
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
ConversionTarget target(*ctx);

target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
target.addIllegalOp<scf::IndexSwitchOp>();

patterns.add<SwitchToIfConversion>(ctx);

if (applyPartialConversion(getOperation(), target, std::move(patterns))
.failed()) {
signalPassFailure();
return;
}
}
};

} // namespace

namespace circt {
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass() {
return std::make_unique<IndexSwitchToIfPass>();
}
} // namespace circt
57 changes: 57 additions & 0 deletions test/Transforms/switch-to-if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: circt-opt -split-input-file --switch-to-if %s | FileCheck %s

// CHECK-LABEL: func.func @example(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> i32 {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_0]] : index
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_3]] : index
// CHECK: %[[VAL_5:.*]] = scf.if %[[VAL_4]] -> (i32) {
// CHECK: %[[VAL_6:.*]] = arith.constant 10 : i32
// CHECK: scf.yield %[[VAL_6]] : i32
// CHECK: } else {
// CHECK: %[[VAL_7:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_7]] : index
// CHECK: %[[VAL_9:.*]] = scf.if %[[VAL_8]] -> (i32) {
// CHECK: %[[VAL_10:.*]] = arith.constant 20 : i32
// CHECK: scf.yield %[[VAL_10]] : i32
// CHECK: } else {
// CHECK: %[[VAL_11:.*]] = arith.constant 7 : index
// CHECK: %[[VAL_12:.*]] = arith.cmpi eq, %[[VAL_2]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = scf.if %[[VAL_12]] -> (i32) {
// CHECK: %[[VAL_14:.*]] = arith.constant 30 : i32
// CHECK: scf.yield %[[VAL_14]] : i32
// CHECK: } else {
// CHECK: %[[VAL_15:.*]] = arith.constant 50 : i32
// CHECK: scf.yield %[[VAL_15]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_13]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_9]] : i32
// CHECK: }
// CHECK: return %[[VAL_5]] : i32
// CHECK: }
module {
func.func @example(%arg0 : index) -> i32 {
%one = arith.constant 1 : index
%cond = arith.addi %one, %arg0 : index
%0 = scf.index_switch %cond -> i32
case 2 {
%1 = arith.constant 10 : i32
scf.yield %1 : i32
}
case 5 {
%2 = arith.constant 20 : i32
scf.yield %2 : i32
}
case 7 {
%3 = arith.constant 30 : i32
scf.yield %3 : i32
}
default {
%4 = arith.constant 50 : i32
scf.yield %4 : i32
}
return %0 : i32
}
}

0 comments on commit d379a35

Please sign in to comment.