[BugFix] Fix parsing integer batch size in AOT
ghstack-source-id: ffd60b71e6e9424b81eeabee77fb8710589f6cae
Pull Request resolved: #1004
vmoens committed Oct 21, 2024
1 parent 1659518 commit ff94c46
Showing 2 changed files with 48 additions and 9 deletions.
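
For context, a minimal sketch of the ahead-of-time (AOT) export scenario this commit targets; the module name MakeTD and the shapes are illustrative assumptions, not part of the commit. When a module builds a TensorDict with an integer batch size taken from a tensor shape and is exported with a dynamic batch dimension, that integer is traced as a torch.SymInt, which the previous isinstance(batch_size, Number) check in _parse_batch_size did not accept.

# Illustrative sketch (assumed names and shapes), mirroring the new test's use of strict=False:
import torch
from tensordict import TensorDict

class MakeTD(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> TensorDict:
        # Eagerly x.shape[0] is a plain int; under export with a dynamic
        # batch dimension it is traced as a torch.SymInt.
        return TensorDict({"x": x}, batch_size=x.shape[0])

ep = torch.export.export(
    MakeTD(),
    args=(torch.zeros(5, 3),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
    strict=False,
)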
10 changes: 5 additions & 5 deletions tensordict/_td.py
@@ -2054,7 +2054,7 @@ def _parse_batch_size(
source: T | dict | None,
batch_size: Sequence[int] | torch.Size | int | None = None,
) -> torch.Size:
ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."

if is_dynamo_compiling():
if isinstance(batch_size, torch.Size):
@@ -2065,22 +2065,22 @@ def _parse_batch_size(
return torch.Size(tuple(batch_size))
if batch_size is None:
return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
-            raise ValueError()
+            raise ValueError(ERR.format(batch_size))

try:
return torch.Size(batch_size)
except Exception:
if batch_size is None:
return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
-            raise ValueError(ERR)
+            raise ValueError(ERR.format(batch_size))

@property
def batch_dims(self) -> int:
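
Outside of export, the eager behaviour that _parse_batch_size keeps is sketched below (assumed values, for illustration only): an integer batch size is wrapped into a one-element torch.Size, and a value that cannot be parsed now raises a ValueError that includes the offending value, per the ERR.format(batch_size) change above.

import torch
from tensordict import TensorDict

# An integer batch size is normalized to a one-element torch.Size.
td = TensorDict({"a": torch.zeros(4, 3)}, batch_size=4)
assert td.batch_size == torch.Size([4])

# An unparseable batch size now reports the rejected value in the message.
try:
    TensorDict({"a": torch.zeros(4, 3)}, batch_size="four")
except ValueError as err:
    print(err)  # "batch size four was not specified ..."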
47 changes: 43 additions & 4 deletions test/test_compile.py
@@ -801,26 +801,65 @@ def call(x, td):


@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
@pytest.mark.parametrize("strict", [True, False])
class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
torch._dynamo.reset_code_caches()
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
x = torch.randn(3)
y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

-    def test_export_seq(self):
+    def test_export_seq(self, strict):
torch._dynamo.reset_code_caches()
tdm = Seq(
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
)
x = torch.randn(3)
y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

+    @pytest.mark.parametrize(
+        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
+    )
+    def test_td_output(self, strict, same_shape, dynamic_shape):
+        # This will only work when the tensordict is pytree-able
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        if same_shape:
+            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
+        else:
+            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        if dynamic_shape:
+            kwargs = {
+                "dynamic_shapes": {
+                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                }
+            }
+        else:
+            kwargs = {}
+
+        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert eager_test.batch_size == export_test.batch_size
+        assert (export_test == eager_test).all()


@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
class TestONNXExport:
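
As a usage note on what the new test_td_output verifies: with dynamic shapes declared, the exported module can be called with a batch size different from the example inputs, and the returned TensorDict reports the same batch_size as the eager module. A self-contained sketch follows; the module name and shapes are illustrative assumptions, not part of the commit.

import torch
from tensordict import TensorDict

class MakeTD(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> TensorDict:
        return TensorDict({"x": x, "y": y}, batch_size=x.shape[0])

batch, time = torch.export.Dim("batch"), torch.export.Dim("time")
ep = torch.export.export(
    MakeTD(),
    args=(torch.zeros(2, 100), torch.zeros(2, 100)),
    dynamic_shapes={"x": {0: batch, 1: time}, "y": {0: batch, 1: time}},
    strict=False,
)

# Call the exported module with a larger batch than the example inputs.
x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
assert ep.module()(x_new, y_new).batch_size == MakeTD()(x_new, y_new).batch_size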
