From 11bb3e3ae70c145f686deacd4ce573868432f03e Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 18 Oct 2024 11:28:38 -0400 Subject: [PATCH] Support convolution with `valid` padding. --- lib/Conversion/TorchToLinalg/Linear.cpp | 6 +++ lib/Conversion/TorchToStablehlo/Linear.cpp | 5 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 ++ projects/pt1/e2e_testing/xfail_sets.py | 3 ++ .../torch_mlir_e2e_test/test_suite/conv.py | 52 +++++++++++++++++++ 5 files changed, 69 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9c914690bbf4..e6f9b81b8436 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -832,6 +832,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); + if (paddingIntValues.size() == 1) { + for (size_t iDim = 1; iDim < numSpatialDims; iDim++) { + paddingIntValues.push_back(paddingIntValues[0]); + } + } + SmallVector outputPaddingIntValues; if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues)) diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index b42ed7cc7722..88617f139c96 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -750,6 +750,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); } + if (padding.size() == 1) { + for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) { + padding.push_back(padding[0]); + } + } SmallVector dilation; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) { return rewriter.notifyMatchFailure(op, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 10f6ecb357fe..766347812c75 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2178,6 +2178,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + if (padding_2d.size() == 1) { + padding_2d.push_back(padding_2d[0]); + } // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 226120302f53..443c1879ec67 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -33,6 +33,9 @@ # if a dimension is specified in all expand lists, and not in sumdim list. # This is a bug in the implementation of _trilinear in PyTorch. "Aten_TrilinearModuleZerodDimBug_basic", + # TorchScript to the backend contract fails for conv.padding specified as str + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index e6332579d575..147885b442c7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -191,6 +191,58 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) +class Conv2dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) +def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + +class Conv2dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) +def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + # ==============================================================================