Add lowering for slice and selectInt (#398)
dan-garvey authored Dec 3, 2021
1 parent 46a2189 commit a52aded
Showing 9 changed files with 416 additions and 15 deletions.
45 changes: 44 additions & 1 deletion e2e_testing/torchscript/basic.py
@@ -100,6 +100,8 @@ def forward(self, lhs, rhs):
def matmul(self, lhs, rhs):
return torch.mm(lhs, rhs)

# ==============================================================================


@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
@@ -192,6 +194,8 @@ def forward(self, x):
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9))

# ==============================================================================


class FlattenStaticModule(torch.nn.Module):
def __init__(self):
@@ -211,6 +215,8 @@ def forward(self, x):
def FlattenStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))

# ==============================================================================


class FlattenRank0Module(torch.nn.Module):
def __init__(self):
@@ -230,6 +236,8 @@ def forward(self, x):
def FlattenRank0Module_basic(module, tu: TestUtils):
module.forward(torch.tensor(4.0))

# ==============================================================================


class FlattenDynamicModule(torch.nn.Module):
def __init__(self):
@@ -249,6 +257,8 @@ def forward(self, x):
def FlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))

# ==============================================================================


class MaxPool2dModule(torch.nn.Module):
def __init__(self):
@@ -266,6 +276,8 @@ def __init__(self):
def forward(self, x):
return self.mp2d(x)

# ==============================================================================


@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
@@ -284,6 +296,8 @@ def __init__(self):
def forward(self, x):
return torch.transpose(x, 0, 1)

# ==============================================================================


@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
@@ -305,6 +319,8 @@ def forward(self, x):
def PermuteModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))

# ==============================================================================

class TransposeIntNegDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -317,6 +333,8 @@ def __init__(self):
def forward(self, x):
return torch.transpose(x, -1, -2)

# ==============================================================================


@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
@@ -335,6 +353,8 @@ def __init__(self):
def forward(self, x):
return x.permute(0, -1, 1)

# ==============================================================================

@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
@@ -357,6 +377,8 @@ def forward(self, x, y, z):
def TensorsConcatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))

# ==============================================================================


class GatherModule(torch.nn.Module):
def __init__(self):
@@ -376,6 +398,8 @@ def forward(self, tensor, indices):
def GatherModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))

# ==============================================================================

class AddSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -396,6 +420,8 @@ def forward(self, tensor):
def AddSizeIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))

# ==============================================================================


class AddSizeIntNegDimModule(torch.nn.Module):
def __init__(self):
@@ -417,6 +443,8 @@ def forward(self, tensor):
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))

# ==============================================================================

class EmbeddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -438,6 +466,7 @@ def forward(self, indices):
def EmbeddingModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 3)))

# ==============================================================================

class SoftmaxIntModule(torch.nn.Module):
def __init__(self):
@@ -474,6 +503,8 @@ def forward(self, tensor):
def _SoftmaxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))

# ==============================================================================


class SoftmaxIntNegDimModule(torch.nn.Module):
def __init__(self):
@@ -494,6 +525,8 @@ def forward(self, tensor):
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))

# ==============================================================================


class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def __init__(self):
@@ -513,6 +546,7 @@ def forward(self, tensor):
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).double())

# ==============================================================================

class BroadcastToModule(torch.nn.Module):
def __init__(self):
@@ -531,6 +565,8 @@ def forward(self, x):
def BroadcastToModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))

# ==============================================================================

class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -548,6 +584,9 @@ def forward(self, x):
def ExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))

# ==============================================================================


class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -563,6 +602,8 @@ def forward(self):
def OnesModuleInt_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class OnesModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -594,6 +635,7 @@ def forward(self):
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class ContiguousModule(torch.nn.Module):
def __init__(self):
@@ -611,7 +653,7 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ContiguousModule())
def ContiguousModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1))

class TensorToInt(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -681,6 +723,7 @@ def forward(self):
def NumToTensorFloatModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

# This test can be removed once we have one real op returning 3 float32 tensors
class ReturnThreeTensorFloat32(torch.nn.Module):
1 change: 1 addition & 0 deletions e2e_testing/torchscript/main.py
@@ -42,6 +42,7 @@
from . import view
from . import scalar
from . import squeeze
from . import slice_like

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
(Diffs for the remaining 7 changed files are collapsed and not shown in this view.)
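The new slice_like test file imported above is among the collapsed diffs, so as a rough illustration only (none of this code is from the commit), the two ATen ops this change adds lowerings for, aten::slice.Tensor and aten::select.int, behave as follows in plain PyTorch:

import torch

# Illustrative sketch, not part of this commit.
# aten::slice.Tensor(self, dim, start, end, step) keeps a strided range along
# one dimension; aten::select.int(self, dim, index) picks a single index along
# a dimension and drops that dimension.
x = torch.rand(4, 5, 6)

sliced = torch.ops.aten.slice(x, 1, 0, 3, 1)   # equivalent to x[:, 0:3, :]
selected = torch.ops.aten.select(x, 0, 2)      # equivalent to x[2]

assert sliced.shape == (4, 3, 6)
assert selected.shape == (5, 6)

Scripting a module whose forward uses Python slicing or integer indexing produces these ops in the TorchScript graph, which is presumably what the new slice_like tests exercise.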
