Skip to content

Commit

Permalink
Fix test_eager_matches_sdpa_inference for XPU backend
Browse files Browse the repository at this point in the history
As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
which is implemented on PyTorch level using aten operators and is device
agnostic with respect to implementation of each aten operator. Thus, we can
reuse CUDA (or CPU) MATH weights for XPU.

Currently XPU backendtorch.nn.attention.sdpa_kernel falls back to CPU when torch
works with XPU backend. So, cpu thresholds should be used in associated
tests.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Nov 25, 2024
1 parent a53aaac commit 01153bc
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/models/mimi/test_modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
12 changes: 12 additions & 0 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", enable_kernels, torch_dtype]
rtol = rtols["cuda", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1691,6 +1697,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
12 changes: 12 additions & 0 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", enable_kernels, torch_dtype]
rtol = rtols["cuda", enable_kernels, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1654,6 +1660,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
6 changes: 6 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4214,6 +4214,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down

0 comments on commit 01153bc

Please sign in to comment.