Skip to content

Commit

Permalink
[Torch] add fold logic for some ops (#3794)
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 authored Oct 16, 2024
1 parent 6b289f2 commit dc7a1ff
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down Expand Up @@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down Expand Up @@ -12641,6 +12643,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

Expand Down Expand Up @@ -15334,6 +15337,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [
Expand Down
134 changes: 134 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

// ===----------------------------------------------------------------------===//
// AtenRSubScalarOp
// ===----------------------------------------------------------------------===//

OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}

// ===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
// ===----------------------------------------------------------------------===//

OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}
std::function<double(ArrayRef<double>)> fpFold;
std::function<APInt(ArrayRef<APInt>)> intFold;

auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}

fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
if (!roundMode) {
return (double)inputs[0] / inputs[1];
} else if (roundMode.getValue().str() == "floor") {
return std::floor((double)inputs[0] / inputs[1]);
} else {
return std::trunc((double)inputs[0] / inputs[1]);
}
};

intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue();
auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue();
int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
int64_t res;
if (roundMode.getValue().str() == "floor") {
res = std::floor(lhs / rhs);
} else {
res = std::trunc(lhs / rhs);
}
return APInt(bits, res);
};

if (!roundMode) {
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, std::nullopt);
}

return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3597,6 +3667,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
}

// ===----------------------------------------------------------------------===//
// AtenRemainderScalarOp
// ===----------------------------------------------------------------------===//

OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}

auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
return std::fmod(inputs[0], inputs[1]);
};

auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
return ret;
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4229,6 +4327,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
auto value = adaptor.getA();
auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
if (!dense || !dense.isSplat()) {
return nullptr;
}

auto splat = dense.getSplatValue<Attribute>();
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
auto type = getType();
if (!isa<mlir::IntegerType>(type)) {
return nullptr;
}

if (type.isSignlessInteger()) {
return getI64IntegerAttr(getContext(), intAttr.getInt());
} else if (type.isSignedInteger()) {
return getI64IntegerAttr(getContext(), intAttr.getSInt());
} else {
return getI64IntegerAttr(getContext(), intAttr.getUInt());
}
}

if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
return getI64IntegerAttr(
getContext(),
static_cast<long>(floatAttr.getValue().convertToDouble()));
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def emit_with_mutating_variants(key, **kwargs):
# variants.
emit_with_mutating_variants(
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
has_folder=True,
has_canonicalizer=True,
)
emit_with_mutating_variants(
Expand Down Expand Up @@ -481,6 +482,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::gelu : (Tensor, str) -> (Tensor)")
Expand Down Expand Up @@ -928,7 +930,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
)
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
emit(
"aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
Expand Down Expand Up @@ -1080,7 +1084,7 @@ def emit_with_mutating_variants(key, **kwargs):
has_canonicalizer=True,
)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
Expand Down
110 changes: 110 additions & 0 deletions test/Dialect/Torch/torch-nary-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
%0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_rsub_scalar_int
func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_rsub_scalar_float
func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_remainder_scalar_int
func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_remainder_scalar_float
func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_int
func.func @fold_aten_int_tensor_int() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_bool
func.func @fold_aten_int_tensor_bool() -> !torch.int {
// CHECK: %int1 = torch.constant.int 1
%cst_false = torch.vtensor.literal(dense<true> : tensor<i1>) : !torch.vtensor<[],i1>
%0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_int_tensor_float
func.func @fold_aten_int_tensor_float() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3.1> : tensor<f32>) : !torch.vtensor<[],f32>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_int
func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%trunc = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_float
func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%floor = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_div_tensor_mode_none
func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%none = torch.constant.none
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

0 comments on commit dc7a1ff

Please sign in to comment.