Skip to content

Commit

Permalink
[MLIR][TORCH] Add E2E support for aten.bitwise_and.tensor op
Browse files Browse the repository at this point in the history
This commit adds lowering of `aten.bitwise_and.tensor` op.

Signed-Off By: Vivek Khandelwal [email protected]
  • Loading branch information
vivekkhandelwal1 committed Dec 2, 2021
1 parent 46a0668 commit 46a2189
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 14 deletions.
21 changes: 21 additions & 0 deletions e2e_testing/torchscript/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,3 +552,24 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

# ==============================================================================
class ElementwiseAndIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.int64, True),
])

def forward(self, x, y):
return torch.bitwise_and(x, y)


@register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
torch.randint(-10, 10, (3, 4)))
15 changes: 15 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,21 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
}

def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement
]> {
Expand Down
43 changes: 30 additions & 13 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenRsqrtOp>(op))
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseAndTensor.emitError(
"Bitwise_And does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseAndTensor.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (isa<AtenLog2Op>(op))
return b.create<math::Log2Op>(loc, payloadArgs[0]);
if (isa<AtenAbsOp>(op))
Expand Down Expand Up @@ -1895,13 +1911,14 @@ struct ConvertElementwiseOp : ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
AtenAbsOp, AtenReciprocalOp>(op))
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
Expand Down Expand Up @@ -3193,12 +3210,12 @@ class ConvertTorchToLinalg
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
AtenReciprocalOp>();
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
return visitBinaryTensorScalarOp(op, operands);
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
AtenMinimumOp, AtenMaximumOp>(op)) {
AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
return visitBinaryBroadcastingOp(op, operands);
} else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
return visitAtenLerpTensorOp(lerpTensor, operands);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",

]:
emit_with_mutating_variants(key)
# Elementwise tensor compute ops that don't have the standard mutating
Expand Down

0 comments on commit 46a2189

Please sign in to comment.