
[TP comm overlap unit test] CUDA Error: misaligned address when testing with recent cuBLAS (or PyTorch container) #1332

Open
erhoo82 opened this issue Nov 14, 2024 · 4 comments


erhoo82 commented Nov 14, 2024

I get a CUDA Error: misaligned address when running the TP comm overlap unit test with a recent PyTorch container.
I think the error comes from the cuBLAS versions that enable nvjet.

[rank1]: Traceback (most recent call last):
[rank1]:   File "/lustre/fsw/coreai_mlperf_training/slym/module_tests/tp_overlap/te.tp_tests/tests/pytorch/distributed/run_gemm_with_overlap.py", line 922, in <module>
[rank1]:     sys.exit(_main(_parse_args()))
[rank1]:              ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank1]:     return f(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/lustre/fsw/coreai_mlperf_training/slym/module_tests/tp_overlap/te.tp_tests/tests/pytorch/distributed/run_gemm_with_overlap.py", line 721, in _main
[rank1]:     all_outputs = _fp8_gemm()
[rank1]:                   ^^^^^^^^^^^
[rank1]:   File "/lustre/fsw/coreai_mlperf_training/slym/module_tests/tp_overlap/te.tp_tests/tests/pytorch/distributed/run_gemm_with_overlap.py", line 602, in _fp8_gemm
[rank1]:     return tex.fp8_gemm(
[rank1]:            ^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 180, in fp8_gemm
[rank1]:     _ = fn(*args)
[rank1]:         ^^^^^^^^^
[rank1]: RuntimeError: /workspace/TransformerEngine/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp:802 in function split_overlap_ag: CUDA Error: misaligned address

denera commented Nov 14, 2024

/workspace/TransformerEngine/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp:802 is a cudaEventRecord call. It seems weird that this would trigger a misaligned address error, so I'm guessing the error actually originates from nvte_cublas_gemm just a few lines above that?

I'm not familiar with nvjet. Does cuBLAS have an environment variable that lets us at least temporarily disable this for debugging?
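
Since the misaligned address is an asynchronous CUDA error, one way to confirm whether it really originates in the GEMM rather than the cudaEventRecord is to force synchronous launches and synchronize around the suspected call. A minimal sketch, based on the call site in the traceback above (the checked_fp8_gemm wrapper is illustrative, not part of the test):

```python
# Sketch: localize the async "misaligned address" error around the GEMM call.
import os

# Must be set before CUDA is initialized so kernel launches become synchronous.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import transformer_engine.pytorch.cpp_extensions as tex  # provides fp8_gemm, per the traceback


def checked_fp8_gemm(*args, **kwargs):
    """Illustrative wrapper: if the GEMM itself raised the error, it surfaces
    at the second synchronize instead of at a later cudaEventRecord."""
    torch.cuda.synchronize()   # flush any pending errors from earlier work
    out = tex.fp8_gemm(*args, **kwargs)
    torch.cuda.synchronize()   # fails here if the GEMM is the culprit
    return out
```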


erhoo82 commented Nov 14, 2024

Not sure if there is a way.

I got the same error in both of the cases below.

  1. Got the above error with the old container when setting LD_LIBRARY_PATH to use the recent cuBLAS build. Without the recent cuBLAS build, the unit test runs fine.
  2. Got the above error with the latest PyTorch container.

The model e2e job with the latest cuBLAS build runs fine.
So I think this is just an issue with the unit test code.
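
For reference, when mixing an older container with LD_LIBRARY_PATH pointing at a newer cuBLAS build, it can be worth double-checking which libcublas the test process actually maps. A small sketch (the matmul only forces cuBLAS to load; it is not part of the unit test):

```python
# Sketch: confirm which libcublas the process actually loaded when overriding
# LD_LIBRARY_PATH with a newer cuBLAS build.
import torch

# A tiny matmul forces PyTorch to initialize cuBLAS.
a = torch.randn(8, 8, device="cuda")
_ = a @ a

# Inspect the process' memory maps for the resolved libcublas path(s).
with open("/proc/self/maps") as maps:
    cublas_paths = sorted({line.split()[-1] for line in maps if "libcublas" in line})
print("\n".join(cublas_paths))
```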


denera commented Nov 14, 2024

Thanks for the info! I'll take a look at the unit tests as soon as I can (likely first thing next week).


denera commented Dec 14, 2024

I started seeing the same misaligned address error with the new TE/JAX API in PR #1337. I wonder if these are related somehow. I will try again with an older container to see if it goes away. If so, I probably need to reach out to the cuBLAS team because it's not clear to me why the unit tests fail when e2e works.
