diff --git a/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 6f41fb686338..47d6785692b7 100644 --- a/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -35,6 +35,7 @@ class Type; namespace mlir { +class MemRefDescriptor; class UnrankedMemRefType; namespace LLVM { @@ -56,8 +57,9 @@ class LLVMTypeConverter : public TypeConverter { /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. - LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic, - SignatureConversion &result); + virtual LLVM::LLVMType convertFunctionSignature(FunctionType type, + bool isVariadic, + SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one values is @@ -71,6 +73,20 @@ class LLVMTypeConverter : public TypeConverter { /// Returns the LLVM dialect. LLVM::LLVMDialect *getDialect() { return llvmDialect; } + /// Create a DefaultMemRefDescriptor object for 'value'. + virtual std::unique_ptr + createMemRefDescriptor(ValuePtr value); + + /// Builds IR creating an uninitialized value of the descriptor type. + virtual std::unique_ptr + buildMemRefDescriptor(OpBuilder &builder, Location loc, Type descriptorType); + /// Builds IR creating a MemRef descriptor that represents `type` and + /// populates it with static shape and stride information extracted from the + /// type. + virtual std::unique_ptr + buildStaticMemRefDescriptor(OpBuilder &builder, Location loc, MemRefType type, + ValuePtr memory); + /// Promote the LLVM struct representation of all MemRef descriptors to stack /// and use pointers to struct to avoid the complexity of the /// platform-specific C/C++ ABI lowering related to struct argument passing. @@ -90,6 +106,9 @@ class LLVMTypeConverter : public TypeConverter { llvm::Module *module; LLVM::LLVMDialect *llvmDialect; + // Extract an LLVM IR dialect type. + LLVM::LLVMType unwrap(Type type); + private: Type convertStandardType(Type type); @@ -129,9 +148,60 @@ class LLVMTypeConverter : public TypeConverter { // Get the LLVM representation of the index type based on the bitwidth of the // pointer as defined by the data layout of the module. LLVM::LLVMType getIndexType(); +}; - // Extract an LLVM IR dialect type. - LLVM::LLVMType unwrap(Type type); +// Base helper class to lower MemRef type to a descriptor in LLVM. Provides an +// abstract API to produce LLVM dialect operations that manipulate the MemRef +// descriptor. Specific MemRef descriptor implementations should inherint from +// this class and implement the API. +struct MemRefDescriptor { + + virtual Value *getValue() = 0; + + /// Builds IR extracting the allocated pointer from the descriptor. + virtual Value *allocatedPtr(OpBuilder &builder, Location loc) = 0; + /// Builds IR inserting the allocated pointer into the descriptor. + virtual void setAllocatedPtr(OpBuilder &builder, Location loc, + Value *ptr) = 0; + + /// Builds IR extracting the aligned pointer from the descriptor. + virtual Value *alignedPtr(OpBuilder &builder, Location loc) = 0; + + /// Builds IR inserting the aligned pointer into the descriptor. + virtual void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) = 0; + + /// Builds IR extracting the offset from the descriptor. + virtual Value *offset(OpBuilder &builder, Location loc) = 0; + + /// Builds IR inserting the offset into the descriptor. + virtual void setOffset(OpBuilder &builder, Location loc, Value *offset) = 0; + + virtual void setConstantOffset(OpBuilder &builder, Location loc, + uint64_t offset) = 0; + + /// Builds IR extracting the pos-th size from the descriptor. + virtual Value *size(OpBuilder &builder, Location loc, unsigned pos) = 0; + + /// Builds IR inserting the pos-th size into the descriptor + virtual void setSize(OpBuilder &builder, Location loc, unsigned pos, + Value *size) = 0; + virtual void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, + uint64_t size) = 0; + + /// Builds IR extracting the pos-th size from the descriptor. + virtual Value *stride(OpBuilder &builder, Location loc, unsigned pos) = 0; + + /// Builds IR inserting the pos-th stride into the descriptor + virtual void setStride(OpBuilder &builder, Location loc, unsigned pos, + Value *stride) = 0; + virtual void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, + uint64_t stride) = 0; + + /// Returns the (LLVM) type this descriptor points to. + virtual LLVM::LLVMType getElementType() = 0; + +protected: + MemRefDescriptor() = default; }; /// Helper class to produce LLVM dialect operations extracting or inserting @@ -144,7 +214,7 @@ class StructBuilder { static StructBuilder undef(OpBuilder &builder, Location loc, Type descriptorType); - /*implicit*/ operator ValuePtr() { return value; } + ValuePtr getValue() { return value; } protected: // LLVM value @@ -158,22 +228,16 @@ class StructBuilder { /// Builds IR to set a value in the struct at position pos void setPtr(OpBuilder &builder, Location loc, unsigned pos, ValuePtr ptr); }; + /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. /// The Value may be null, in which case none of the operations are valid. -class MemRefDescriptor : public StructBuilder { +class DefaultMemRefDescriptor : public StructBuilder, public MemRefDescriptor { public: /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(ValuePtr descriptor); - /// Builds IR creating an `undef` value of the descriptor type. - static MemRefDescriptor undef(OpBuilder &builder, Location loc, - Type descriptorType); - /// Builds IR creating a MemRef descriptor that represents `type` and - /// populates it with static shape and stride information extracted from the - /// type. - static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, ValuePtr memory); + explicit DefaultMemRefDescriptor(ValuePtr descriptor); + + ValuePtr getValue() override { return StructBuilder::getValue(); }; /// Builds IR extracting the allocated pointer from the descriptor. ValuePtr allocatedPtr(OpBuilder &builder, Location loc); diff --git a/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index bf18ea03dab0..9813710e2168 100644 --- a/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -623,9 +623,10 @@ struct GPUFuncOpLowering : LLVMOpLowering { // and canonicalize that away later. ValuePtr attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution->getType().cast(); - auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, - type, memory); - signatureConversion.remapInput(numProperArguments + en.index(), descr); + auto descr = + lowering.buildStaticMemRefDescriptor(rewriter, loc, type, memory); + signatureConversion.remapInput(numProperArguments + en.index(), + descr->getValue()); } // Rewrite private memory attributions to alloca'ed buffers. @@ -649,10 +650,11 @@ struct GPUFuncOpLowering : LLVMOpLowering { rewriter.getI64IntegerAttr(type.getNumElements())); ValuePtr allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); - auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, - type, allocated); + auto descr = lowering.buildStaticMemRefDescriptor(rewriter, loc, type, + allocated); signatureConversion.remapInput( - numProperArguments + numWorkgroupAttributions + en.index(), descr); + numProperArguments + numWorkgroupAttributions + en.index(), + descr->getValue()); } } diff --git a/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 8b6b9fb79303..f93b342207da 100644 --- a/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -117,32 +117,33 @@ namespace { /// EDSC-compatible wrapper for MemRefDescriptor. class BaseViewConversionHelper { public: - BaseViewConversionHelper(Type type) - : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} + BaseViewConversionHelper(Type type, LLVMTypeConverter &typeConverter) + : d(typeConverter.buildMemRefDescriptor(rewriter(), loc(), type)) {} - BaseViewConversionHelper(ValuePtr v) : d(v) {} + BaseViewConversionHelper(ValuePtr v, LLVMTypeConverter &typeConverter) + : d(typeConverter.createMemRefDescriptor(v)) {} /// Wrappers around MemRefDescriptor that use EDSC builder and location. - ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } - void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); } - ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); } - void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); } - ValuePtr offset() { return d.offset(rewriter(), loc()); } - void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); } - ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); } - void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); } - ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + ValuePtr allocatedPtr() { return d->allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(ValuePtr v) { d->setAllocatedPtr(rewriter(), loc(), v); } + ValuePtr alignedPtr() { return d->alignedPtr(rewriter(), loc()); } + void setAlignedPtr(ValuePtr v) { d->setAlignedPtr(rewriter(), loc(), v); } + ValuePtr offset() { return d->offset(rewriter(), loc()); } + void setOffset(ValuePtr v) { d->setOffset(rewriter(), loc(), v); } + ValuePtr size(unsigned i) { return d->size(rewriter(), loc(), i); } + void setSize(unsigned i, ValuePtr v) { d->setSize(rewriter(), loc(), i, v); } + ValuePtr stride(unsigned i) { return d->stride(rewriter(), loc(), i); } void setStride(unsigned i, ValuePtr v) { - d.setStride(rewriter(), loc(), i, v); + d->setStride(rewriter(), loc(), i, v); } - operator ValuePtr() { return d; } + operator ValuePtr() { return d->getValue(); } private: OpBuilder &rewriter() { return ScopedContext::getBuilder(); } Location loc() { return ScopedContext::getLocation(); } - MemRefDescriptor d; + std::unique_ptr d; }; } // namespace @@ -190,14 +191,15 @@ class SliceOpConversion : public LLVMOpLowering { ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); SliceOpOperandAdaptor adaptor(operands); - BaseViewConversionHelper baseDesc(adaptor.view()); + BaseViewConversionHelper baseDesc(adaptor.view(), lowering); auto sliceOp = cast(op); auto memRefType = sliceOp.getBaseViewType(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) .cast(); - BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); + BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()), + lowering); // TODO(ntv): extract sizes and emit asserts. SmallVector strides(memRefType.getRank()); @@ -282,7 +284,7 @@ class TransposeOpConversion : public LLVMOpLowering { // Initialize the common boilerplate and alloca at the top of the FuncOp. edsc::ScopedContext context(rewriter, op->getLoc()); TransposeOpOperandAdaptor adaptor(operands); - BaseViewConversionHelper baseDesc(adaptor.view()); + BaseViewConversionHelper baseDesc(adaptor.view(), lowering); auto transposeOp = cast(op); // No permutation, early exit. @@ -290,7 +292,7 @@ class TransposeOpConversion : public LLVMOpLowering { return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); BaseViewConversionHelper desc( - lowering.convertType(transposeOp.getViewType())); + lowering.convertType(transposeOp.getViewType()), lowering); // Copy the base and aligned pointers from the old descriptor to the new // one. diff --git a/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 67b545c4ec84..292c7c7548be 100644 --- a/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -274,74 +274,36 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr(pos)); } /*============================================================================*/ -/* MemRefDescriptor implementation */ +/* DefaultMemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(ValuePtr descriptor) +DefaultMemRefDescriptor::DefaultMemRefDescriptor(ValuePtr descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value->getType().cast().getStructElementType( kOffsetPosInMemRefDescriptor); } -/// Builds IR creating an `undef` value of the descriptor type. -MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, - Type descriptorType) { - - ValuePtr descriptor = - builder.create(loc, descriptorType.cast()); - return MemRefDescriptor(descriptor); -} - -/// Builds IR creating a MemRef descriptor that represents `type` and -/// populates it with static shape and stride information extracted from the -/// type. -MemRefDescriptor -MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, ValuePtr memory) { - assert(type.hasStaticShape() && "unexpected dynamic shape"); - assert(type.getAffineMaps().empty() && "unexpected layout map"); - - auto convertedType = typeConverter.convertType(type); - assert(convertedType && "unexpected failure in memref type conversion"); - - auto descr = MemRefDescriptor::undef(builder, loc, convertedType); - descr.setAllocatedPtr(builder, loc, memory); - descr.setAlignedPtr(builder, loc, memory); - descr.setConstantOffset(builder, loc, 0); - - // Fill in sizes and strides, in reverse order to simplify stride - // calculation. - uint64_t runningStride = 1; - for (unsigned i = type.getRank(); i > 0; --i) { - unsigned dim = i - 1; - descr.setConstantSize(builder, loc, dim, type.getDimSize(dim)); - descr.setConstantStride(builder, loc, dim, runningStride); - runningStride *= type.getDimSize(dim); - } - return descr; -} - /// Builds IR extracting the allocated pointer from the descriptor. -ValuePtr MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { +ValuePtr DefaultMemRefDescriptor::allocatedPtr(OpBuilder &builder, + Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. -void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, ValuePtr ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. -ValuePtr MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { +ValuePtr DefaultMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. -void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, ValuePtr ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } @@ -355,14 +317,14 @@ static ValuePtr createIndexAttrConstant(OpBuilder &builder, Location loc, } /// Builds IR extracting the offset from the descriptor. -ValuePtr MemRefDescriptor::offset(OpBuilder &builder, Location loc) { +ValuePtr DefaultMemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. -void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, ValuePtr offset) { value = builder.create( loc, structType, value, offset, @@ -370,22 +332,22 @@ void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, } /// Builds IR inserting the offset into the descriptor. -void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset) { setOffset(builder, loc, createIndexAttrConstant(builder, loc, indexType, offset)); } /// Builds IR extracting the pos-th size from the descriptor. -ValuePtr MemRefDescriptor::size(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr DefaultMemRefDescriptor::size(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th size into the descriptor -void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, +void DefaultMemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, ValuePtr size) { value = builder.create( loc, structType, value, size, @@ -393,22 +355,22 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, } /// Builds IR inserting the pos-th size into the descriptor -void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size) { setSize(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, size)); } /// Builds IR extracting the pos-th size from the descriptor. -ValuePtr MemRefDescriptor::stride(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr DefaultMemRefDescriptor::stride(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor -void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, +void DefaultMemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, ValuePtr stride) { value = builder.create( loc, structType, value, stride, @@ -416,13 +378,13 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, } /// Builds IR inserting the pos-th stride into the descriptor -void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, +void DefaultMemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride) { setStride(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, stride)); } -LLVM::LLVMType MemRefDescriptor::getElementType() { +LLVM::LLVMType DefaultMemRefDescriptor::getElementType() { return value->getType().cast().getStructElementType( kAlignedPtrPosInMemRefDescriptor); } @@ -1027,9 +989,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Create the MemRef descriptor. auto structType = lowering.convertType(type); - auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + auto memRefDescriptor = + lowering.buildMemRefDescriptor(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. - memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); + memRefDescriptor->setAllocatedPtr(rewriter, loc, bitcastAllocated); // Field 2: Actual aligned pointer to payload. ValuePtr bitcastAligned = bitcastAllocated; @@ -1049,15 +1012,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern { bitcastAligned = rewriter.create( loc, elementPtrType, ArrayRef(aligned)); } - memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); + memRefDescriptor->setAlignedPtr(rewriter, loc, bitcastAligned); // Field 3: Offset in aligned pointer. - memRefDescriptor.setOffset(rewriter, loc, - createIndexConstant(rewriter, loc, offset)); + memRefDescriptor->setOffset(rewriter, loc, + createIndexConstant(rewriter, loc, offset)); if (type.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. - return rewriter.replaceOp(op, {memRefDescriptor}); + return rewriter.replaceOp(op, {memRefDescriptor->getValue()}); // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as @@ -1082,12 +1045,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(sizes)) { int64_t index = indexedSize.index(); - memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); - memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); + memRefDescriptor->setSize(rewriter, loc, index, indexedSize.value()); + memRefDescriptor->setStride(rewriter, loc, index, strideValues[index]); } // Return the final value of the descriptor. - rewriter.replaceOp(op, {memRefDescriptor}); + rewriter.replaceOp(op, {memRefDescriptor->getValue()}); } bool useAlloca; @@ -1193,10 +1156,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { /*isVarArg=*/false)); } - MemRefDescriptor memref(transformed.memref()); + auto memref = lowering.createMemRefDescriptor(transformed.memref()); ValuePtr casted = rewriter.create( op->getLoc(), getVoidPtrType(), - memref.allocatedPtr(rewriter, op->getLoc())); + memref->allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); return matchSuccess(); @@ -1325,7 +1288,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); - rewriter.replaceOp(op, (ValuePtr)memRefDesc); + rewriter.replaceOp(op, memRefDesc.getValue()); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. @@ -1365,9 +1328,11 @@ struct DimOpLowering : public LLVMLegalizationPattern { auto shape = type.getShape(); int64_t index = dimOp.getIndex(); // Extract dynamic size from the memref descriptor. - if (ShapedType::isDynamic(shape[index])) - rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor()) - .size(rewriter, op->getLoc(), index)}); + if (ShapedType::isDynamic(shape[index])) { + auto memref = + lowering.createMemRefDescriptor(transformed.memrefOrTensor()); + rewriter.replaceOp(op, memref->size(rewriter, op->getLoc(), index)); + } else // Use constant for static size. rewriter.replaceOp( @@ -1420,21 +1385,19 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. ValuePtr getStridedElementPtr(Location loc, Type elementTypePtr, - ValuePtr descriptor, ArrayRef indices, + MemRefDescriptor *memRefDescriptor, + ArrayRef indices, ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const { - MemRefDescriptor memRefDescriptor(descriptor); - - ValuePtr base = memRefDescriptor.alignedPtr(rewriter, loc); - ValuePtr offsetValue = - offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : this->createIndexConstant(rewriter, loc, offset); + ValuePtr base = memRefDescriptor->alignedPtr(rewriter, loc); + ValuePtr offsetValue = offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor->offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { ValuePtr stride = strides[i] == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.stride(rewriter, loc, i) + ? memRefDescriptor->stride(rewriter, loc, i) : this->createIndexConstant(rewriter, loc, strides[i]); ValuePtr additionalOffset = rewriter.create(loc, indices[i], stride); @@ -1444,11 +1407,11 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { return rewriter.create(loc, elementTypePtr, base, offsetValue); } - ValuePtr getDataPtr(Location loc, MemRefType type, ValuePtr memRefDesc, - ArrayRef indices, + ValuePtr getDataPtr(Location loc, MemRefType type, + MemRefDescriptor *memRefDesc, ArrayRef indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const { - LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); + LLVM::LLVMType ptrType = memRefDesc->getElementType(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(type, strides, offset); @@ -1471,10 +1434,12 @@ struct LoadOpLowering : public LoadStoreOpLowering { OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + auto memrefDesc = lowering.createMemRefDescriptor(transformed.memref()); + ValuePtr dataPtr = + this->getDataPtr(op->getLoc(), type, memrefDesc.get(), + transformed.indices(), rewriter, this->getModule()); rewriter.replaceOpWithNewOp(op, dataPtr); - return matchSuccess(); + return this->matchSuccess(); } }; @@ -1489,11 +1454,13 @@ struct StoreOpLowering : public LoadStoreOpLowering { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + auto memrefDesc = lowering.createMemRefDescriptor(transformed.memref()); + ValuePtr dataPtr = + this->getDataPtr(op->getLoc(), type, memrefDesc.get(), + transformed.indices(), rewriter, this->getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); - return matchSuccess(); + return this->matchSuccess(); } }; @@ -1509,7 +1476,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { OperandAdaptor transformed(operands); auto type = prefetchOp.getMemRefType(); - ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + auto memrefDesc = lowering.createMemRefDescriptor(transformed.memref()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, memrefDesc.get(), transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. @@ -1850,25 +1818,26 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(operands.front()); - auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); + auto sourceMemRef = lowering.createMemRefDescriptor(operands.front()); + auto targetMemRef = + lowering.buildMemRefDescriptor(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. - ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr extracted = sourceMemRef->allocatedPtr(rewriter, loc); ValuePtr bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + targetMemRef->setAllocatedPtr(rewriter, loc, bitcastPtr); - extracted = sourceMemRef.alignedPtr(rewriter, loc); + extracted = sourceMemRef->alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); + targetMemRef->setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. SmallVector strideValues; strideValues.reserve(viewMemRefType.getRank()); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) - strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); + strideValues.push_back(sourceMemRef->stride(rewriter, loc, i)); // Fill in missing dynamic sizes. auto llvmIndexType = lowering.convertType(rewriter.getIndexType()); @@ -1882,18 +1851,18 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { } // Offset. - ValuePtr baseOffset = sourceMemRef.offset(rewriter, loc); + ValuePtr baseOffset = sourceMemRef->offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { ValuePtr min = dynamicOffsets[i]; baseOffset = rewriter.create( loc, baseOffset, rewriter.create(loc, min, strideValues[i])); } - targetMemRef.setOffset(rewriter, loc, baseOffset); + targetMemRef->setOffset(rewriter, loc, baseOffset); // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { - targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); + targetMemRef->setSize(rewriter, loc, i, dynamicSizes[i]); ValuePtr newStride; if (dynamicStrides.empty()) newStride = rewriter.create( @@ -1901,10 +1870,10 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { else newStride = rewriter.create(loc, dynamicStrides[i], strideValues[i]); - targetMemRef.setStride(rewriter, loc, i, newStride); + targetMemRef->setStride(rewriter, loc, i, newStride); } - rewriter.replaceOp(op, {targetMemRef}); + rewriter.replaceOp(op, {targetMemRef->getValue()}); return matchSuccess(); } }; @@ -1974,20 +1943,21 @@ struct ViewOpLowering : public LLVMLegalizationPattern { matchFailure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(adaptor.source()); - auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); + auto sourceMemRef = lowering.createMemRefDescriptor(adaptor.source()); + auto targetMemRef = + lowering.buildMemRefDescriptor(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. - ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr extracted = sourceMemRef->allocatedPtr(rewriter, loc); ValuePtr bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + targetMemRef->setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. - extracted = sourceMemRef.alignedPtr(rewriter, loc); + extracted = sourceMemRef->alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); + targetMemRef->setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: Copy the offset in aligned pointer. unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); @@ -1997,14 +1967,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + (hasDynamicOffset ? 1 : 0)); ValuePtr baseOffset = !hasDynamicOffset - ? createIndexConstant(rewriter, loc, offset) - // TODO(ntv): better adaptor. - : sizeAndOffsetOperands.front(); - targetMemRef.setOffset(rewriter, loc, baseOffset); + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.front(); + targetMemRef->setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(op, {targetMemRef}), matchSuccess(); + return rewriter.replaceOp(op, {targetMemRef->getValue()}), matchSuccess(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) @@ -2019,14 +1989,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // Update size. ValuePtr size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); - targetMemRef.setSize(rewriter, loc, i, size); + targetMemRef->setSize(rewriter, loc, i, size); // Update stride. stride = getStride(rewriter, loc, strides, nextSize, stride, i); - targetMemRef.setStride(rewriter, loc, i, stride); + targetMemRef->setStride(rewriter, loc, i, stride); nextSize = size; } - rewriter.replaceOp(op, {targetMemRef}); + rewriter.replaceOp(op, {targetMemRef->getValue()}); return matchSuccess(); } }; @@ -2182,6 +2152,50 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } +/// Create a DefaultMemRefDescriptor object for 'value'. +std::unique_ptr +LLVMTypeConverter::createMemRefDescriptor(ValuePtr value) { + return std::make_unique(value); +} + +/// Builds IR creating an `undef` value of the descriptor type. +std::unique_ptr +LLVMTypeConverter::buildMemRefDescriptor(OpBuilder &builder, Location loc, + Type descriptorType) { + ValuePtr descriptor = + builder.create(loc, descriptorType.cast()); + return std::make_unique(descriptor); +} + +/// Builds IR creating a MemRef descriptor that represents `type` and +/// populates it with static shape and stride information extracted from the +/// type. +std::unique_ptr +LLVMTypeConverter::buildStaticMemRefDescriptor(OpBuilder &builder, Location loc, + MemRefType type, ValuePtr memory) { + assert(type.hasStaticShape() && "unexpected dynamic shape"); + assert(type.getAffineMaps().empty() && "unexpected layout map"); + + auto convertedType = convertType(type); + assert(convertedType && "unexpected failure in memref type conversion"); + + auto descr = buildMemRefDescriptor(builder, loc, convertedType); + descr->setAllocatedPtr(builder, loc, memory); + descr->setAlignedPtr(builder, loc, memory); + descr->setConstantOffset(builder, loc, 0); + + // Fill in sizes and strides, in reverse order to simplify stride + // calculation. + uint64_t runningStride = 1; + for (unsigned i = type.getRank(); i > 0; --i) { + unsigned dim = i - 1; + descr->setConstantSize(builder, loc, dim, type.getDimSize(dim)); + descr->setConstantStride(builder, loc, dim, runningStride); + runningStride *= type.getDimSize(dim); + } + return descr; +} + ValuePtr LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, ValuePtr operand, OpBuilder &builder) { diff --git a/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5099cb01bbc4..8d2a031fedb7 100644 --- a/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -554,7 +554,7 @@ class VectorTypeCastOpConversion : public LLVMOpLowering { operands[0]->getType().dyn_cast(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) return matchFailure(); - MemRefDescriptor sourceMemRef(operands[0]); + auto sourceMemRef = lowering.createMemRefDescriptor(operands[0]); auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) .dyn_cast_or_null(); @@ -582,21 +582,22 @@ class VectorTypeCastOpConversion : public LLVMOpLowering { auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); // Create descriptor. - auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); - Type llvmTargetElementTy = desc.getElementType(); + auto desc = + lowering.buildMemRefDescriptor(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc->getElementType(); // Set allocated ptr. - ValuePtr allocated = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr allocated = sourceMemRef->allocatedPtr(rewriter, loc); allocated = rewriter.create(loc, llvmTargetElementTy, allocated); - desc.setAllocatedPtr(rewriter, loc, allocated); + desc->setAllocatedPtr(rewriter, loc, allocated); // Set aligned ptr. - ValuePtr ptr = sourceMemRef.alignedPtr(rewriter, loc); + ValuePtr ptr = sourceMemRef->alignedPtr(rewriter, loc); ptr = rewriter.create(loc, llvmTargetElementTy, ptr); - desc.setAlignedPtr(rewriter, loc, ptr); + desc->setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); auto zero = rewriter.create(loc, int64Ty, attr); - desc.setOffset(rewriter, loc, zero); + desc->setOffset(rewriter, loc, zero); // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { @@ -604,14 +605,14 @@ class VectorTypeCastOpConversion : public LLVMOpLowering { auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); auto size = rewriter.create(loc, int64Ty, sizeAttr); - desc.setSize(rewriter, loc, index, size); + desc->setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); auto stride = rewriter.create(loc, int64Ty, strideAttr); - desc.setStride(rewriter, loc, index, stride); + desc->setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(op, {desc}); + rewriter.replaceOp(op, {desc->getValue()}); return matchSuccess(); } }; diff --git a/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir similarity index 70% rename from test/Conversion/StandardToLLVM/convert-memref-ops.mlir rename to test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir index d92ded7f3aa9..2b5cff4cd04b 100644 --- a/test/Conversion/StandardToLLVM/convert-memref-ops.mlir +++ b/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -1,11 +1,12 @@ -// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s -// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA +// RUN: mlir-opt -convert-std-to-llvm -split-input-file %s | FileCheck %s // CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref, %mixed : memref<10x?xf32>) { return } +// ----- + // CHECK-LABEL: func @check_strided_memref_arguments( // CHECK-COUNT-3: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> func @check_strided_memref_arguments(%static: memref<10x20xf32, (i,j)->(20 * i + j + 1)>, @@ -14,73 +15,7 @@ func @check_strided_memref_arguments(%static: memref<10x20xf32, (i,j)->(20 * i + return } -// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { -func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { -// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - return %static : memref<32x18xf32> -} - -// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { -// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { -func @zero_d_alloc() -> memref { -// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 -// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> - -// ALLOCA-NOT: malloc -// ALLOCA: alloca -// ALLOCA-NOT: malloc - %0 = alloc() : memref - return %0 : memref -} - -// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) { -func @zero_d_dealloc(%arg0: memref) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> -// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () - dealloc %arg0 : memref - return -} - -// CHECK-LABEL: func @aligned_1d_alloc( -func @aligned_1d_alloc() -> memref<42xf32> { -// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 -// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 -// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 -// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 -// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %9[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> - %0 = alloc() {alignment = 8} : memref<42xf32> - return %0 : memref<42xf32> -} +// ----- // CHECK-LABEL: func @mixed_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> { @@ -114,6 +49,8 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref { return %0 : memref } +// ----- + // CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*">) { func @mixed_dealloc(%arg0: memref) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> @@ -125,6 +62,8 @@ func @mixed_dealloc(%arg0: memref) { return } +// ----- + // CHECK-LABEL: func @dynamic_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { func @dynamic_alloc(%arg0: index, %arg1: index) -> memref { @@ -152,6 +91,8 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref { return %0 : memref } +// ----- + // CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) { func @dynamic_dealloc(%arg0: memref) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -162,60 +103,7 @@ func @dynamic_dealloc(%arg0: memref) { return } -// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { -func @static_alloc() -> memref<32x18xf32> { -// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 -// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 -// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64 -// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> -// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 -// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> -// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> - %0 = alloc() : memref<32x18xf32> - return %0 : memref<32x18xf32> -} - -// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) { -func @static_dealloc(%static: memref<10x8xf32>) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> -// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () - dealloc %static : memref<10x8xf32> - return -} - -// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float { -func @zero_d_load(%arg0: memref) -> f32 { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm<"float*"> - %0 = load %arg0[] : memref - return %0 : f32 -} - -// CHECK-LABEL: func @static_load( -// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 -func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> -// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 -// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> -// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*"> - %0 = load %static[%i, %j] : memref<10x42xf32> - return -} +// ----- // CHECK-LABEL: func @mixed_load( // CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 @@ -235,6 +123,8 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { return } +// ----- + // CHECK-LABEL: func @dynamic_load( // CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 func @dynamic_load(%dynamic : memref, %i : index, %j : index) { @@ -253,6 +143,8 @@ func @dynamic_load(%dynamic : memref, %i : index, %j : index) { return } +// ----- + // CHECK-LABEL: func @prefetch func @prefetch(%A : memref, %i : index, %j : index) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -311,6 +203,8 @@ func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f return } +// ----- + // CHECK-LABEL: func @dynamic_store func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f32) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -328,6 +222,8 @@ func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f return } +// ----- + // CHECK-LABEL: func @mixed_store func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -345,6 +241,8 @@ func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) return } +// ----- + // CHECK-LABEL: func @memref_cast_static_to_dynamic func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -353,6 +251,8 @@ func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_static_to_mixed func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -361,6 +261,8 @@ func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_dynamic_to_static func @memref_cast_dynamic_to_static(%dynamic : memref) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -369,6 +271,8 @@ func @memref_cast_dynamic_to_static(%dynamic : memref) { return } +// ----- + // CHECK-LABEL: func @memref_cast_dynamic_to_mixed func @memref_cast_dynamic_to_mixed(%dynamic : memref) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -377,6 +281,8 @@ func @memref_cast_dynamic_to_mixed(%dynamic : memref) { return } +// ----- + // CHECK-LABEL: func @memref_cast_mixed_to_dynamic func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -385,6 +291,8 @@ func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_mixed_to_static func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -393,6 +301,8 @@ func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_mixed_to_mixed func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> @@ -401,6 +311,8 @@ func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_ranked_to_unranked func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }*"> @@ -416,6 +328,8 @@ func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) { return } +// ----- + // CHECK-LABEL: func @memref_cast_unranked_to_ranked func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) { // CHECK: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ i64, i8* }*"> @@ -425,6 +339,8 @@ func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) { return } +// ----- + // CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) { // CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> @@ -441,19 +357,3 @@ func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) { return } -// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { -func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> -// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 - %0 = dim %static, 0 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 - %1 = dim %static, 1 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 - %2 = dim %static, 2 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 - %3 = dim %static, 3 : memref<42x32x15x13x27xf32> -// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 - %4 = dim %static, 4 : memref<42x32x15x13x27xf32> - return -} - diff --git a/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir new file mode 100644 index 000000000000..1c0b7440a914 --- /dev/null +++ b/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -0,0 +1,310 @@ +// RUN: mlir-opt -convert-std-to-llvm -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA +// RUN: mlir-opt -test-custom-memref-llvm-lowering %s | FileCheck %s --check-prefix=CUSTOM + +// CUSTOM-LABEL: func @check_noalias +// CUSTOM-SAME: arg0: !llvm<"float**"> {llvm.noalias = true} +func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) { + return +} + +// ----- + +// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +// CUSTOM-LABEL: func @check_static_return +// CUSTOM-SAME: (%arg0: !llvm<"float**">) -> !llvm<"float*"> +func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { +// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CUSTOM: llvm.return %{{.*}} : !llvm<"float*"> + return %static : memref<32x18xf32> +} + +// ----- + +// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { +// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> { +// CUSTOM-LABEL: func @zero_d_alloc() -> !llvm<"float*"> +func @zero_d_alloc() -> memref { +// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> + +// ALLOCA-NOT: malloc +// ALLOCA: alloca +// ALLOCA-NOT: malloc + +// CUSTOM: llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CUSTOM: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CUSTOM: %[[size:.*]] = llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CUSTOM: llvm.call @malloc(%[[size]]) : (!llvm.i64) -> !llvm<"i8*"> +// CUSTOM: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CUSTOM: llvm.return %[[ptr]] : !llvm<"float*"> + %0 = alloc() : memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) { +// CUSTOM-LABEL: func @zero_d_dealloc +// CUSTOM-SAME: (%{{.*}}: !llvm<"float**">) { +func @zero_d_dealloc(%arg0: memref) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + +// CUSTOM: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"float**"> +// CUSTOM: %[[bc:.*]] = llvm.bitcast %[[ld]] : !llvm<"float*"> to !llvm<"i8*"> +// CUSTOM: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + dealloc %arg0 : memref + return +} + +// ----- + +// CHECK-LABEL: func @aligned_1d_alloc( +// CUSTOM-LABEL: func @aligned_1d_alloc( +func @aligned_1d_alloc() -> memref<42xf32> { +// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 +// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 +// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr %{{.*}}[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + +// CUSTOM-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CUSTOM-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CUSTOM-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 +// CUSTOM-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[alignmentMinus1:.*]] = llvm.add %5{{.*}}, %[[alignment]] : !llvm.i64 +// CUSTOM-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one]] : !llvm.i64 +// CUSTOM-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm<"i8*"> +// CUSTOM-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> +// CUSTOM-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm<"i8*"> to !llvm.i64 +// CUSTOM-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 +// CUSTOM-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 +// CUSTOM-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 +// CUSTOM-NEXT: %[[aligned:.*]] = llvm.getelementptr %{{.*}}[%[[alignAdj3]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> +// CUSTOM-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %15 : !llvm<"i8*"> to !llvm<"float*"> +// Alignment is not implemented in custom lowering so base ptr is returned. +// CUSTOM: llvm.return %[[ptr]] : !llvm<"float*"> + %0 = alloc() {alignment = 8} : memref<42xf32> + return %0 : memref<42xf32> +} + +// ----- + +// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> { +// CUSTOM-LABEL: func @static_alloc() +// CUSTOM-SAME: -> !llvm<"float*"> +func @static_alloc() -> memref<32x18xf32> { +// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64 +// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> +// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> + +// CUSTOM-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[num_elems:.*]] = llvm.mul %[[sz1]], %[[sz2]] : !llvm.i64 +// CUSTOM-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*"> +// CUSTOM-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64 +// CUSTOM-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 +// CUSTOM-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*"> +// CUSTOM-NEXT: %[[bc:.*]] = llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*"> +// CUSTOM: llvm.return %[[bc]] : !llvm<"float*"> + %0 = alloc() : memref<32x18xf32> + return %0 : memref<32x18xf32> +} + +// ----- + +// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">) { +// CUSTOM-LABEL: func @static_dealloc +// CUSTOM-SAME: (%{{.*}}: !llvm<"float**">) +func @static_dealloc(%static: memref<10x8xf32>) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*"> +// CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + +// CUSTOM-NEXT: %[[ld:.*]] = llvm.load %arg0 : !llvm<"float**"> +// CUSTOM-NEXT: %[[bc:.*]] = llvm.bitcast %[[ld]] : !llvm<"float*"> to !llvm<"i8*"> +// CUSTOM-NEXT: llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> () + dealloc %static : memref<10x8xf32> + return +} + +// ----- + +// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float*, float*, i64 }*">) -> !llvm.float { +// CUSTOM-LABEL: func @zero_d_load +// CUSTOM-SAME: (%{{.*}}: !llvm<"float**">) -> !llvm.float +func @zero_d_load(%arg0: memref) -> f32 { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + +// CUSTOM-NEXT: %[[ptr:.*]] = llvm.load %arg0 : !llvm<"float**"> +// CUSTOM-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: llvm.load %[[addr:.*]] : !llvm<"float*"> + %0 = load %arg0[] : memref + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func @static_load( +// CHECK: %[[A:.*]]: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64 + +// CUSTOM-LABEL: func @static_load +// CUSTOM-SAME: (%[[A:.*]]: !llvm<"float**">, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) { +func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %[[A]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + +// CUSTOM-NEXT: %[[ptr:.*]] = llvm.load %[[A]] : !llvm<"float**"> +// CUSTOM-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CUSTOM-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CUSTOM-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CUSTOM-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CUSTOM-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: llvm.load %[[addr]] : !llvm<"float*"> + %0 = load %static[%i, %j] : memref<10x42xf32> + return +} + +// ----- + +// CHECK-LABEL: func @zero_d_store +// CHECK-SAME: (%[[A:.*]]: !llvm<"{ float*, float*, i64 }*">, %[[val:.*]]: !llvm.float) +// CUSTOM-LABEL: func @zero_d_store +// CUSTOM-SAME: (%[[A:.*]]: !llvm<"float**">, %[[val:.*]]: !llvm.float) +func @zero_d_store(%arg0: memref, %arg1: f32) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64 }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64 }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.store %[[val]], %[[addr]] : !llvm<"float*"> + +// CUSTOM-NEXT: %[[ptr:.*]] = llvm.load %[[A]] : !llvm<"float**"> +// CUSTOM-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: llvm.store %[[val]], %[[addr]] : !llvm<"float*"> + store %arg1, %arg0[] : memref + return +} + +// ----- + +// CHECK-LABEL: func @static_store +// CUSTOM-LABEL: func @static_store +func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK-NEXT: %[[ptr:.*]] = llvm.extractvalue %[[ld]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*"> + +// CUSTOM-NEXT: %[[ptr:.*]] = llvm.load %{{.*}} : !llvm<"float**"> +// CUSTOM-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 +// CUSTOM-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 +// CUSTOM-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CUSTOM-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 +// CUSTOM-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CUSTOM-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> +// CUSTOM-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm<"float*"> + store %val, %static[%i, %j] : memref<10x42xf32> + return +} + +// ----- + +// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { +// CUSTOM-LABEL: func @static_memref_dim +// CUSTOM-SAME: (%arg0: !llvm<"float**">) +func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { +// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> +// CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 + +// CUSTOM-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"float**"> +// CUSTOM-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 + %0 = dim %static, 0 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 +// CUSTOM-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 + %1 = dim %static, 1 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 +// CUSTOM-NEXT: llvm.mlir.constant(15 : index) : !llvm.i64 + %2 = dim %static, 2 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 +// CUSTOM-NEXT: llvm.mlir.constant(13 : index) : !llvm.i64 + %3 = dim %static, 3 : memref<42x32x15x13x27xf32> +// CHECK-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 +// CUSTOM-NEXT: llvm.mlir.constant(27 : index) : !llvm.i64 + %4 = dim %static, 4 : memref<42x32x15x13x27xf32> + return +} + diff --git a/test/lib/Transforms/CMakeLists.txt b/test/lib/Transforms/CMakeLists.txt index b6338e1d167c..1c1a3a8d4d49 100644 --- a/test/lib/Transforms/CMakeLists.txt +++ b/test/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_library(MLIRTestTransforms TestCallGraph.cpp TestConstantFold.cpp + TestCustomMemRefLLVMLowering.cpp TestLoopFusion.cpp TestInlining.cpp TestLinalgTransforms.cpp diff --git a/test/lib/Transforms/TestCustomMemRefLLVMLowering.cpp b/test/lib/Transforms/TestCustomMemRefLLVMLowering.cpp new file mode 100644 index 000000000000..dd1b232b256c --- /dev/null +++ b/test/lib/Transforms/TestCustomMemRefLLVMLowering.cpp @@ -0,0 +1,190 @@ +//===- TestCustomMemRefLLVMLowering.cpp - Pass to test strides +// computation--===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +/// Test pass that lowers MemRef type to LLVM using a custom descriptor. +struct TestCustomMemRefLLVMLowering + : public ModulePass { + void runOnModule() override; +}; + +/// Custom MemRef descriptor that lowers MemRef types to LLVM plain pointers. +/// Alignment and dynamic shapes are currently not supported. +class CustomMemRefDescriptor : public MemRefDescriptor { +public: + /// Construct a helper for the given descriptor value. + explicit CustomMemRefDescriptor(Value *descriptor) : value(descriptor){}; + + Value *getValue() override { return value; } + + /// Builds IR extracting the allocated pointer from the descriptor. + Value *allocatedPtr(OpBuilder &builder, Location loc) override { + return value; + }; + /// Builds IR inserting the allocated pointer into the descriptor. + void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) override { + value = ptr; + }; + + /// Builds IR extracting the aligned pointer from the descriptor. + Value *alignedPtr(OpBuilder &builder, Location loc) override { + return allocatedPtr(builder, loc); + }; + + /// Builds IR inserting the aligned pointer into the descriptor. + void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) override{ + // Alignment is not supported by this memref descriptor. + // 'alignedPtr' returns allocatedPtr instead. + }; + + /// Builds IR extracting the offset from the descriptor. + Value *offset(OpBuilder &builder, Location loc) override { + llvm_unreachable("'offset' is not implemented in CustomMemRefDescriptor"); + }; + + /// Builds IR inserting the offset into the descriptor. + void setOffset(OpBuilder &builder, Location loc, Value *offset) override{}; + + void setConstantOffset(OpBuilder &builder, Location loc, + uint64_t offset) override{}; + + /// Builds IR extracting the pos-th size from the descriptor. + Value *size(OpBuilder &builder, Location loc, unsigned pos) override { + llvm_unreachable("'size' is not implemented in CustomMemRefDescriptor"); + }; + + /// Builds IR inserting the pos-th size into the descriptor + void setSize(OpBuilder &builder, Location loc, unsigned pos, + Value *size) override{}; + void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, + uint64_t size) override{}; + + /// Builds IR extracting the pos-th size from the descriptor. + Value *stride(OpBuilder &builder, Location loc, unsigned pos) override { + llvm_unreachable("'stride' is not implemented in CustomMemRefDescriptor"); + }; + + /// Builds IR inserting the pos-th stride into the descriptor + void setStride(OpBuilder &builder, Location loc, unsigned pos, + Value *stride) override{}; + void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, + uint64_t stride) override{}; + + /// Returns the (LLVM) type this descriptor points to. + LLVM::LLVMType getElementType() override { + return value->getType().cast(); + } + +private: + Value *value; +}; + +/// Provides Std-to-LLVM type conversion by using CustomMemRefDescriptor to +/// lower MemRef types. Falls back to base LLVMTypeConverter for the remaining +/// types. +class CustomLLVMTypeConverter : public mlir::LLVMTypeConverter { +public: + using LLVMTypeConverter::LLVMTypeConverter; + + Type convertType(Type type) override { + if (auto memrefTy = type.dyn_cast()) { + return convertMemRefType(memrefTy); + } + + // Fall back to base class converter. + return LLVMTypeConverter::convertType(type); + } + + /// Creates a CustomMemRefDescriptor object for 'value'. + std::unique_ptr + createMemRefDescriptor(Value *value) override { + return std::make_unique(value); + } + + /// Creates a CustomMemRefDescriptor object for an uninitialized descriptor + /// (nullptr value). No new IR is needed for such initialization. + std::unique_ptr + buildMemRefDescriptor(OpBuilder &builder, Location loc, + Type descriptorType) override { + return createMemRefDescriptor(nullptr); + } + + /// Builds IR creating a MemRef descriptor that represents `type` and + /// populates it with static shape and stride information extracted from the + /// type. + std::unique_ptr + buildStaticMemRefDescriptor(OpBuilder &builder, Location loc, MemRefType type, + Value *memory) override { + assert(type.hasStaticShape() && "unexpected dynamic shape"); + assert(type.getAffineMaps().empty() && "unexpected layout map"); + + auto convertedType = convertType(type); + assert(convertedType && "unexpected failure in memref type conversion"); + + auto descr = buildMemRefDescriptor(builder, loc, convertedType); + descr->setAllocatedPtr(builder, loc, memory); + return descr; + } + +private: + /// Converts MemRef type to plain LLVM pointer to element type. + Type convertMemRefType(MemRefType type) { + int64_t offset; + SmallVector strides; + bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); + assert(strideSuccess && + "Non-strided layout maps must have been normalized away"); + (void)strideSuccess; + + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + if (!elementType) + return {}; + auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); + return ptrTy; + } +}; + +} // end anonymous namespace + +void TestCustomMemRefLLVMLowering::runOnModule() { + // Populate Std-to-LLVM conversion patterns using the custom type converter. + CustomLLVMTypeConverter typeConverter(&getContext()); + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(typeConverter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + if (failed(applyPartialConversion(getModule(), target, patterns, + &typeConverter))) + signalPassFailure(); +} + +static PassRegistration + pass("test-custom-memref-llvm-lowering", + "Test custom LLVM lowering of memrefs"); +