To reproduce

```python
from torchao import quantize_
from torchao.quantization import int8_weight_only
from torch import nn
import torch

linear = nn.Linear(1024, 1024)
quantize_(linear, int8_weight_only())
linear.cuda()
linear.compile()
linear(torch.randn(1, 1024, device="cuda"))
linear.cpu()   # this will error
linear.cuda()  # this will also error
```
Error
```
Traceback (most recent call last):
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/xxx/python3.10/site-packages/torch/utils/__init__.py", line 51, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xxx/debug.py", line 11, in <module>
    linear.cpu()
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 1118, in cpu
    return self._apply(lambda t: t.cpu())
  File "/home/xxx/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
```
This seems to be a general problem with tensor subclasses + torch.compile, not something limited to AQT. Even calling compile(disable=False) still produces this error.
cc: @jerryzh168
torchao: 0.7.0+git26648c2c (installed from source)
pytorch: tested with 2.5.0 and 2.6.0.dev20241102+cu124