
xpu: test_eager_matches_sdpa_inference tests fail with pytorch XPU backend #34888

Closed · dvrogozh opened this issue Nov 22, 2024 · 1 comment · Fixed by #34889
Labels: bug, Tests (Related to tests)

dvrogozh (Contributor) commented Nov 22, 2024

With:

$ cat spec.py
import torch
DEVICE_NAME = 'xpu'
MANUAL_SEED_FN = torch.xpu.manual_seed
EMPTY_CACHE_FN = torch.xpu.empty_cache
DEVICE_COUNT_FN = torch.xpu.device_count

$ TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python3 -m pytest --pspec tests/models -k test_eager_matches_sdpa_inference
<...>
FAILED tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py::ASTModelTest::test_eager_matches_sdpa_inference_0_float16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 4.739e-05,...
FAILED tests/models/audio_spectrogram_transformer/test_modeling_audio_spectrogram_transformer.py::ASTModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 5.913e-04,...
FAILED tests/models/bart/test_modeling_bart.py::BartModelTest::test_eager_matches_sdpa_inference_0_float16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 7.510e-06,...
FAILED tests/models/bart/test_modeling_bart.py::BartModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 7.772e-05,...
FAILED tests/models/bart/test_modeling_bart.py::BartStandaloneDecoderModelTest::test_eager_matches_sdpa_inference_0_float16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 2.402e-05,...
FAILED tests/models/bart/test_modeling_bart.py::BartStandaloneDecoderModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 3.490e-04,...
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_eager_matches_sdpa_inference_0_float16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 5.555e-05,...
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_eager_matches_sdpa_inference_1_bfloat16 - AssertionError: False is not true : padding_side=left, use_mask=False, enable_kernels=False: mean relative difference: 3.567e-04,...
<...>
======================= 159 failed, 89 passed, 793 skipped, 75366 deselected, 319 warnings in 74.89s (0:01:14) =======================

CC: @amyeroberts @ydshieh
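
For context, the test harness reads the file named by TRANSFORMERS_TEST_DEVICE_SPEC and picks up DEVICE_NAME (plus the optional *_FN hooks) from it. Below is a minimal sketch of that loading step, assuming a plain importlib-based loader rather than the exact logic in transformers/testing_utils.py:

import importlib.util
import os

# Hypothetical re-implementation of the device-spec loading; the real
# code in transformers/testing_utils.py may validate more strictly.
spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
module_name = os.path.splitext(os.path.basename(spec_path))[0]
spec = importlib.util.spec_from_file_location(module_name, spec_path)
device_spec = importlib.util.module_from_spec(spec)
spec.loader.exec_module(device_spec)

torch_device = device_spec.DEVICE_NAME                      # required, e.g. 'xpu'
manual_seed = getattr(device_spec, "MANUAL_SEED_FN", None)  # optional hook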

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Nov 22, 2024
Currently torch.nn.attention.sdpa_kernel falls back to CPU when torch
works with the XPU backend, so CPU thresholds should be used in the
associated tests.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>
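
For reference, a minimal sketch of pinning the SDPA backend with the torch.nn.attention.sdpa_kernel context manager (the tensors here are placeholders, not taken from the test suite):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)

# Restrict scaled_dot_product_attention to the MATH backend. On XPU this
# is effectively what happens today, since kernel selection falls back
# to the device-agnostic aten implementation.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)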
dvrogozh (Contributor, Author) commented:

Please help review the PR with the fix: #34889

Rocketknight1 added the bug and Tests (Related to tests) labels Nov 25, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Nov 25, 2024
As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
which is implemented at the PyTorch level using aten operators and is
device-agnostic with respect to the implementation of each aten operator.
Thus, we can reuse the CUDA (or CPU) MATH thresholds for XPU.

Currently torch.nn.attention.sdpa_kernel falls back to CPU when torch
works with the XPU backend, so CPU thresholds should be used in the
associated tests.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>
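
A sketch of the threshold reuse this implies, with an illustrative tolerance table (the keys and values below are made up, not copied from test_modeling_common.py):

import torch

# Illustrative per-(device, dtype) tolerances for comparing eager vs. SDPA
# outputs; the real test keeps separate atol/rtol tables per backend.
atols = {
    ("cpu", torch.float16): 5e-3,
    ("cuda", torch.float16): 5e-3,
}

def get_atol(device: str, dtype: torch.dtype) -> float:
    # XPU only has the MATH backend, whose aten reference implementation
    # is device-agnostic, so the CUDA (or CPU) thresholds apply to it too.
    if device == "xpu":
        device = "cuda"
    return atols[(device, dtype)]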
ydshieh pushed a commit that referenced this issue Dec 2, 2024
* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel

Signed-off-by: Dmitry Rogozhkin <[email protected]>

* Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
which is implemented at the PyTorch level using aten operators and is
device-agnostic with respect to the implementation of each aten operator.
Thus, we can reuse the CUDA (or CPU) MATH thresholds for XPU.

Fixes: #34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>

* Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin <[email protected]>

---------

Signed-off-by: Dmitry Rogozhkin <[email protected]>
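
For reference, the deprecated and current autocast spellings side by side (a generic example; the actual nemotron call site is not reproduced here):

import torch

x = torch.randn(8, 8, device="cuda")

# Deprecated: with torch.cuda.amp.autocast(dtype=torch.float16): ...
# Current, device-generic API takes the device type as the first argument:
with torch.amp.autocast("cuda", dtype=torch.float16):
    y = x @ x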