Support convolution with valid padding.
sahas3 committed Oct 18, 2024
1 parent dc7a1ff commit 9a9e409
Showing 5 changed files with 69 additions and 0 deletions.
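
For context, PyTorch convolution modules accept string padding ("valid" or "same") in addition to explicit integers, and after decomposition the padding can reach the backend lowerings as a single-element list, which this commit broadcasts. A minimal sketch of the PyTorch-level semantics; the shapes in the comments follow from the standard convolution output formula:

import torch

# "valid" applies no padding; "same" pads so the output spatial size equals
# the input spatial size (stride must be 1 for "same").
conv_valid = torch.nn.Conv2d(1, 1, 3, padding="valid")
conv_same = torch.nn.Conv2d(1, 1, 3, padding="same")

x = torch.rand(1, 1, 5, 6)
print(conv_valid(x).shape)  # torch.Size([1, 1, 3, 4]): each dim shrinks by kernel - 1
print(conv_same(x).shape)   # torch.Size([1, 1, 5, 6]): matches the input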
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -832,6 +832,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
          op, "only support padding from a list construct");
    paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
                                              paddingIntValues);
    // A length-1 padding list applies to every spatial dimension; broadcast
    // the single value before lowering.
    if (paddingIntValues.size() == 1) {
      for (size_t iDim = 1; iDim < numSpatialDims; iDim++) {
        paddingIntValues.push_back(paddingIntValues[0]);
      }
    }

    SmallVector<Value> outputPaddingIntValues;
    if (!getListConstructElements(op.getOutputPadding(),
                                  outputPaddingIntValues))
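
The hunk above broadcasts a length-1 padding list across every spatial dimension. The same rule as a standalone Python sketch (broadcast_padding is an illustrative helper, not part of the patch):

def broadcast_padding(padding, num_spatial_dims):
    # A single padding value applies to every spatial dimension, mirroring
    # the loop added to the Linalg lowering above.
    if len(padding) == 1:
        padding = padding * num_spatial_dims
    return padding

assert broadcast_padding([2], 3) == [2, 2, 2]
assert broadcast_padding([1, 3], 2) == [1, 3]  # longer lists pass through unchanged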
5 changes: 5 additions & 0 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -750,6 +750,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
      return rewriter.notifyMatchFailure(op,
                                         "non-const padding list unsupported");
    }
    // Broadcast a single padding value across all spatial dimensions.
    if (padding.size() == 1) {
      for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) {
        padding.push_back(padding[0]);
      }
    }
    SmallVector<int64_t> dilation;
    if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) {
      return rewriter.notifyMatchFailure(op,
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2163,6 +2163,9 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                    m_TorchListOfConstantInts(padding_2d)))
    return rewriter.notifyMatchFailure(op,
                                       "non-const padding list unsupported");
  // A single padding value applies to both spatial dimensions.
  if (padding_2d.size() == 1) {
    padding_2d.push_back(padding_2d[0]);
  }
  // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
  // The Torch OFM computation uses 2*pad in each spatial direction, implying
  // the same t=b and l=r values for TOSA.
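
The comment in the hunk above describes how Torch's symmetric 2D padding expands to TOSA's 4D form; a small illustrative helper (hypothetical, not part of the patch) makes the mapping concrete:

def torch_pad_to_tosa(pad_tl):
    # Torch supplies symmetric {top, left}; TOSA expects {top, bottom, left, right}.
    # Symmetric padding (2*pad per spatial dim) implies top == bottom, left == right.
    top, left = pad_tl
    return [top, top, left, left]

assert torch_pad_to_tosa([1, 2]) == [1, 1, 2, 2]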
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -29,6 +29,9 @@
    "DeformConv2D_basic",
    "ReduceAnyDimFloatModule_basic",
    "UnfoldModule_basic",
    # Lowering TorchScript to the backend contract fails when conv padding is specified as a string.
    "Conv2dWithValidPaddingModule_basic",
    "Conv2dWithSamePaddingModule_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
52 changes: 52 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -191,6 +191,58 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(
    module.forward(tu.rand(5, 4, 10, 20))


class Conv2dWithValidPaddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(
            1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1
        )
        self.train(False)

    @export
    @annotate_args(
        [
            None,
            ([1, 5, 6], torch.float32, True),
        ]
    )
    def forward(self, x):
        return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule())
def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils):
    t = tu.rand(1, 5, 6)
    module.forward(t)


class Conv2dWithSamePaddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(
            1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1
        )
        self.train(False)

    @export
    @annotate_args(
        [
            None,
            ([1, 5, 6], torch.float32, True),
        ]
    )
    def forward(self, x):
        return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule())
def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils):
    t = tu.rand(1, 5, 6)
    module.forward(t)


# ==============================================================================


