diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index a4962d12abdc0..9e74b66493ead 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -832,6 +832,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); + if (paddingIntValues.size() == 1) { + for (size_t iDim = 1; iDim < numSpatialDims; iDim++) { + paddingIntValues.push_back(paddingIntValues[0]); + } + } + SmallVector outputPaddingIntValues; if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues)) diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index b42ed7cc77227..88617f139c96c 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -750,6 +750,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); } + if (padding.size() == 1) { + for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) { + padding.push_back(padding[0]); + } + } SmallVector dilation; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) { return rewriter.notifyMatchFailure(op, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e5f4fea4f46c5..7539bf49bf9dd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2163,6 +2163,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + if (padding_2d.size() == 1) { + padding_2d.push_back(padding_2d[0]); + } // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e7512fc89e982..179cd15e4c95c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,6 +29,9 @@ "DeformConv2D_basic", "ReduceAnyDimFloatModule_basic", "UnfoldModule_basic", + # TorchScript to the backend contract fails for conv.padding specified as str + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 3bc1760489460..25f4acddcd03d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -191,6 +191,58 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) +class Conv2dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) +def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + +class Conv2dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) +def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + # ==============================================================================