Skip to content

Commit

Permalink
[VectorOps] Add lowering of vector.insert to LLVM IR
Browse files Browse the repository at this point in the history
For example, an insert

  %0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>

becomes

  %0 = llvm.mlir.constant(3 : i32) : !llvm.i32
  %1 = llvm.insertelement %arg0, %arg1[%0 : !llvm.i32] : !llvm<"<4 x float>">

A more elaborate example, inserting an element into a higher-dimensional
vector

  %0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>

becomes

  %0 = llvm.extractvalue %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
  %1 = llvm.mlir.constant(15 : i32) : !llvm.i32
  %2 = llvm.insertelement %arg0, %0[%1 : !llvm.i32] : !llvm<"<16 x float>">
  %3 = llvm.insertvalue %2, %arg1[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">

PiperOrigin-RevId: 284882443
  • Loading branch information
aartbik authored and tensorflower-gardener committed Dec 11, 2019
1 parent 4d8ba88 commit 9826fe5
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 24 deletions.
124 changes: 100 additions & 24 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
.getPointerTo();
}

// Returns the vector type obtained from `tp` by dropping its leading
// dimension while keeping the element type. Requires a rank above one.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  auto trailingShape = tp.getShape().drop_front();
  auto elementType = tp.getElementType();
  return VectorType::get(trailingShape, elementType);
}

// Returns the 1-D vector type formed from the trailing dimension of `tp`,
// i.e. every dimension except the last is dropped. The element type is kept.
// Requires a rank above one.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  auto innermostShape = tp.getShape().take_back();
  auto elementType = tp.getElementType();
  return VectorType::get(innermostShape, elementType);
}

class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
explicit VectorBroadcastOpConversion(MLIRContext *context,
Expand Down Expand Up @@ -135,8 +147,9 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
Value *expand = expandRanks(value, loc, srcVectorType,
reducedVectorType(dstVectorType), rewriter);
Value *expand =
expandRanks(value, loc, srcVectorType,
reducedVectorTypeFront(dstVectorType), rewriter);
Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
Expand Down Expand Up @@ -183,8 +196,8 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
}
} else {
VectorType redSrcType = reducedVectorType(srcVectorType);
VectorType redDstType = reducedVectorType(dstVectorType);
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
Expand Down Expand Up @@ -226,18 +239,12 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
rewriter.getI64ArrayAttr(pos));
}

// Helper to reduce vector type by one rank.
static VectorType reducedVectorType(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}
};

class VectorExtractElementOpConversion : public LLVMOpLowering {
class VectorExtractOpConversion : public LLVMOpLowering {
public:
explicit VectorExtractElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
explicit VectorExtractOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
typeConverter) {}

Expand All @@ -247,11 +254,15 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.vector()->getType().cast<VectorType>();
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult()->getType();
auto llvmResultType = lowering.convertType(resultType);

auto positionArrayAttr = extractOp.position();

// Bail if result type cannot be lowered.
if (!llvmResultType)
return matchFailure();

// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
Expand All @@ -260,15 +271,12 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {
return matchSuccess();
}

// Potential extraction of 1-D vector from struct.
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
Value *extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
auto i32Type = rewriter.getIntegerType(32);
if (positionAttrs.size() > 1) {
auto nDVectorType = vectorType;
auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
nDVectorType.getElementType());
auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
Expand All @@ -278,8 +286,8 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {

// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
auto constant = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(i32Type), position);
auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
Expand All @@ -288,6 +296,73 @@ class VectorExtractElementOpConversion : public LLVMOpLowering {
}
};

// Conversion pattern that lowers vector.insert to the LLVM dialect.
//
// A vector source is written into the destination with one llvm.insertvalue.
// A scalar source is written with llvm.insertelement into a 1-D vector; when
// the position has more than one index, that 1-D vector is first read out of
// the destination with llvm.extractvalue and afterwards written back with
// llvm.insertvalue.
class VectorInsertOpConversion : public LLVMOpLowering {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto insertOp = cast<vector::InsertOp>(op);
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto loc = op->getLoc();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = lowering.convertType(destVectorType);

    // Give up when the destination type has no LLVM counterpart.
    if (!llvmResultType)
      return matchFailure();

    auto positionArrayAttr = insertOp.position();

    // Vector source: the whole position addresses an array slot, so a single
    // insertvalue suffices.
    if (insertOp.getSourceType().isa<VectorType>()) {
      Value *updated = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, updated);
      return matchSuccess();
    }

    // Scalar source: the last index selects the lane of a 1-D vector, and any
    // leading indices select that 1-D vector inside the destination.
    auto *context = op->getContext();
    auto positionAttrs = positionArrayAttr.getValue();
    auto laneIndex = positionAttrs.back().cast<IntegerAttr>();
    const bool isNested = positionAttrs.size() > 1;
    auto leadingIndices = [&]() {
      return ArrayAttr::get(positionAttrs.drop_back(), context);
    };

    auto rowVectorType = destVectorType;
    if (isNested)
      rowVectorType = reducedVectorTypeBack(destVectorType);
    auto llvmRowType = lowering.convertType(rowVectorType);

    // Fetch the 1-D vector to update when it is nested inside the destination.
    Value *row = adaptor.dest();
    if (isNested)
      row = rewriter.create<LLVM::ExtractValueOp>(loc, llvmRowType, row,
                                                  leadingIndices());

    // Write the element into the 1-D LLVM vector at a constant i32 index.
    auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
    auto laneConstant =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, laneIndex);
    Value *updated = rewriter.create<LLVM::InsertElementOp>(
        loc, llvmRowType, row, adaptor.source(), laneConstant);

    // Put the updated 1-D vector back into the original destination.
    if (isNested)
      updated = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), updated, leadingIndices());

    rewriter.replaceOp(op, updated);
    return matchSuccess();
  }
};

class VectorOuterProductOpConversion : public LLVMOpLowering {
public:
explicit VectorOuterProductOpConversion(MLIRContext *context,
Expand Down Expand Up @@ -431,8 +506,9 @@ class VectorTypeCastOpConversion : public LLVMOpLowering {
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastOpConversion, VectorExtractElementOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion,
VectorInsertOpConversion, VectorOuterProductOpConversion,
VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}

Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,15 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">

// Scalar extract from a 1-D vector with one position index: the checks below
// expect an i32 constant 15 followed by llvm.extractelement on <16 x float>.
func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
%0 = vector.extract %arg0[15 : i32]: vector<16xf32>
return %0 : f32
}
// CHECK-LABEL: extract_element_from_vec_1d
// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return {{.*}} : !llvm.float

func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
%0 = vector.extract %arg0[0 : i32]: vector<4x3x16xf32>
return %0 : vector<3x16xf32>
Expand All @@ -238,6 +247,14 @@ func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32>
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]">

// Extract of a 1-D vector from a 3-D vector using two position indices: the
// checks below expect llvm.extractvalue at [0, 0] and a <16 x float> return.
func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> {
%0 = vector.extract %arg0[0 : i32, 0 : i32]: vector<4x3x16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: extract_vec_1d_from_vec_3d
// CHECK: llvm.extractvalue {{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"<16 x float>">

func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
%0 = vector.extract %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
return %0 : f32
Expand All @@ -248,6 +265,42 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.return {{.*}} : !llvm.float

// Scalar insert into a 1-D vector with one position index: the checks below
// expect an i32 constant 3 followed by llvm.insertelement on <4 x float>.
func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert %arg0, %arg1[3 : i32] : f32 into vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: insert_element_into_vec_1d
// CHECK: llvm.mlir.constant(3 : i32) : !llvm.i32
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<4 x float>">
// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>">

// Insert of a 2-D vector into a 3-D vector with one position index: the checks
// below expect llvm.insertvalue at [3] on the 3-D array-of-vectors type.
func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
%0 = vector.insert %arg0, %arg1[3 : i32] : vector<8x16xf32> into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_vec_2d_into_vec_3d
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">

// Insert of a 1-D vector into a 3-D vector with two position indices: the
// checks below expect llvm.insertvalue at [3, 7] on the 3-D array type.
func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32] : vector<16xf32> into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_vec_1d_into_vec_3d
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">

// Scalar insert into a 3-D vector with three position indices: the checks
// below expect, in order, llvm.extractvalue at [3, 7], an i32 constant 15,
// llvm.insertelement on <16 x float>, and llvm.insertvalue back at [3, 7].
func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
%0 = vector.insert %arg0, %arg1[3 : i32, 7 : i32, 15 : i32] : f32 into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
}
// CHECK-LABEL: insert_element_into_vec_3d
// CHECK: llvm.extractvalue {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.mlir.constant(15 : i32) : !llvm.i32
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i32] : !llvm<"<16 x float>">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3 : i32, 7 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]">

func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
return %0 : memref<vector<8x8x8xf32>>
Expand Down

0 comments on commit 9826fe5

Please sign in to comment.