Skip to content

Commit

Permalink
Fix test_eager_matches_sdpa_inference for XPU backend
Browse files Browse the repository at this point in the history
Currently torch.nn.attention.sdpa_kernel falls back to CPU when torch
works with XPU backend. So, cpu thresholds should be used in associated
tests.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Nov 23, 2024
1 parent 2c8aa44 commit 567a6a3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# TODO: update once PyTorch XPU backend will support sdpa kernels
atol = atols["cpu", enable_kernels, torch_dtype]
rtol = rtols["cpu", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1691,6 +1695,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# TODO: update once PyTorch XPU backend will support sdpa kernels
atol = atols["cpu", enable_kernels, torch_dtype]
rtol = rtols["cpu", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
8 changes: 8 additions & 0 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# TODO: update once PyTorch XPU backend will support sdpa kernels
atol = atols["cpu", enable_kernels, torch_dtype]
rtol = rtols["cpu", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1654,6 +1658,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# TODO: update once PyTorch XPU backend will support sdpa kernels
atol = atols["cpu", enable_kernels, torch_dtype]
rtol = rtols["cpu", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
4 changes: 4 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4209,6 +4209,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# TODO: update once PyTorch XPU backend will support sdpa kernels
atol = atols["cpu", enable_kernels, torch_dtype]
rtol = rtols["cpu", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down

0 comments on commit 567a6a3

Please sign in to comment.