Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Fix the wrong computation of dynamic strides for lowering AllocOp to LLVM #338

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,19 @@ class PassOptions : protected llvm::cl::SubCommand {

/// Utility methods for printing option values.
template <typename DataT>
static void printOptionValue(raw_ostream &os,
GenericOptionParser<DataT> &parser,
const DataT &value) {
static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
const DataT &value) {
if (Optional<StringRef> argStr = parser.findArgStrForValue(value))
os << argStr;
else
llvm_unreachable("unknown data value for option");
}
template <typename DataT, typename ParserT>
static void printOptionValue(raw_ostream &os, ParserT &parser,
const DataT &value) {
static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
os << value;
}
template <typename ParserT>
static void printOptionValue(raw_ostream &os, ParserT &parser,
const bool &value) {
static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
os << (value ? StringRef("true") : StringRef("false"));
}

Expand Down Expand Up @@ -129,7 +126,7 @@ class PassOptions : protected llvm::cl::SubCommand {
/// Print the name and value of this option to the given stream.
void print(raw_ostream &os) final {
os << this->ArgStr << '=';
printOptionValue(os, this->getParser(), this->getValue());
printValue(os, this->getParser(), this->getValue());
}

/// Copy the value from the given option into this one.
Expand Down Expand Up @@ -172,7 +169,7 @@ class PassOptions : protected llvm::cl::SubCommand {
void print(raw_ostream &os) final {
os << this->ArgStr << '=';
auto printElementFn = [&](const DataType &value) {
printOptionValue(os, this->getParser(), value);
printValue(os, this->getParser(), value);
};
interleave(*this, os, printElementFn, ",");
}
Expand Down
9 changes: 5 additions & 4 deletions lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,14 +1054,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
SmallVector<Value, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
int64_t index = nStrides - 1 - indexedStride.index();
for (unsigned i = 0; i < nStrides; ++i) {
int64_t index = nStrides - 1 - i;
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
// Identity layout map is enforced in the match function, so we compute:
// `runningStride *= sizes[index]`
// `runningStride *= sizes[index + 1]`
runningStride =
runningStride
? rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[index])
? rewriter.create<LLVM::MulOp>(loc, runningStride,
sizes[index + 1])
: createIndexConstant(rewriter, loc, 1);
else
runningStride = createIndexConstant(rewriter, loc, strides[index]);
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/StandardToLLVM/convert-memref-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[c42]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
Expand Down Expand Up @@ -142,7 +142,7 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[M]] : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
Expand Down