torch.compile(sync_float8_amax_and_scale_history) not working with triton latest main #1311

Open
goldhuang opened this issue Nov 19, 2024 · 2 comments

Comments

@goldhuang

[rank0]:   File "/opt/venv/lib/python3.10/site-packages/lightning_fabric/wrappers.py", line 411, in _capture
[rank0]:     return compile_fn(*args, **kwargs)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/__init__.py", line 2448, in compile
[rank0]:     return torch._dynamo.optimize(
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 716, in optimize
[rank0]:     return _optimize(rebuild_ctx, *args, **kwargs)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 790, in _optimize
[rank0]:     compiler_config=backend.get_compiler_config()
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/__init__.py", line 2238, in get_compiler_config
[rank0]:     from torch._inductor.compile_fx import get_patched_config_dict
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 49, in <module>
[rank0]:     from torch._inductor.debug import save_args_for_compile_fx_inner
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_inductor/debug.py", line 26, in <module>
[rank0]:     from . import config, ir  # noqa: F811, this is needed
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_inductor/ir.py", line 77, in <module>
[rank0]:     from .runtime.hints import ReductionHint
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/_inductor/runtime/hints.py", line 36, in <module>
[rank0]:     attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
[rank0]:   File "/usr/lib/python3.10/dataclasses.py", line 1198, in fields
[rank0]:     raise TypeError('must be called with a dataclass type or instance') from None
[rank0]: TypeError: must be called with a dataclass type or instance
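The TypeError is raised while importing torch._inductor.runtime.hints, where dataclasses.fields() is called on Triton's AttrsDescriptor. A minimal sketch of that failure mode, assuming (not confirmed here) that AttrsDescriptor on recent Triton main is no longer a dataclass:

from dataclasses import fields

class AttrsDescriptor:  # hypothetical stand-in for a non-dataclass definition
    pass

try:
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
except TypeError as e:
    print(e)  # "must be called with a dataclass type or instance"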
@vkuzo
Contributor

vkuzo commented Nov 19, 2024

Hi @goldhuang, could you please share some more information on how to reproduce this?

@goldhuang
Author

import torch

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
    sync_float8_amax_and_scale_history,
)

# torchao_scale_type, torchao_enable_float8_all_gather, and model come from the
# surrounding training setup.
if torchao_scale_type == "delayed":
    scaling_type_input = ScalingType.DELAYED
    scaling_type_weight = ScalingType.DELAYED
    scaling_type_grad_output = ScalingType.DELAYED
else:
    scaling_type_input = ScalingType.DYNAMIC
    scaling_type_weight = ScalingType.DYNAMIC
    scaling_type_grad_output = ScalingType.DYNAMIC

convert_to_float8_training(
    model,
    config=Float8LinearConfig(
        enable_fsdp_float8_all_gather=torchao_enable_float8_all_gather,
        cast_config_input=CastConfig(scaling_type=scaling_type_input),
        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
        enable_pre_and_post_forward=False,
        enable_amax_init=False,
        pad_inner_dim=True,
    ),
    module_filter_fn=model.torchao_module_filter_fn,
)

if torchao_scale_type == "delayed":
    _sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history)
else:
    _precompute_float8_dynamic_scale_for_fsdp = precompute_float8_dynamic_scale_for_fsdp
The call _sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history) runs into the error above with torch 2.5.0+cu124 and the latest triton main branch.
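A minimal repro sketch, assuming torch 2.5.0+cu124 with Triton built from latest main; per the traceback, the failing import happens inside torch.compile() itself, before the compiled function is ever called:

import torch
from torchao.float8 import sync_float8_amax_and_scale_history

# torch.compile() eagerly imports torch._inductor.compile_fx via
# backend.get_compiler_config(), which is where the TypeError is raised.
compiled_sync = torch.compile(sync_float8_amax_and_scale_history)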
