
Incorrect type in output of utils.pad_across_processes when input is torch.bool #3218

Status: Open
mariusarvinte opened this issue Nov 4, 2024 · 2 comments · May be fixed by #3219
Comments


mariusarvinte commented Nov 4, 2024

System Info

- `Accelerate` version: 1.1.0
- Platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35
- `accelerate` bash location: .venv/bin/accelerate
- Python version: 3.11.10
- Numpy version: 2.1.1
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 755.50 GB
- GPU type: NVIDIA RTX 6000 Ada Generation
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Running the following code with `accelerate launch example.py`:

import torch
from accelerate import Accelerator
from accelerate.utils import pad_across_processes

accelerator = Accelerator()

# Each process creates a bool tensor whose size along dim 1 depends on its rank.
process_tensor = (torch.randn(2, 100 * (accelerator.process_index + 1)) > 0).to(accelerator.device)
print(f"{process_tensor.shape = }, {process_tensor.dtype = }, {accelerator.process_index = }")

# Pad every tensor along dim 1 to the largest size across processes.
padded_tensor = pad_across_processes(process_tensor, dim=1)
print(f"{padded_tensor.shape = }, {padded_tensor.dtype = }, {accelerator.process_index = }")

On a machine with at least two GPUs, this produces output like the following (example for three GPUs; line order varies by process):

process_tensor.shape = torch.Size([2, 300]), process_tensor.dtype = torch.bool, accelerator.process_index = 2
process_tensor.shape = torch.Size([2, 100]), process_tensor.dtype = torch.bool, accelerator.process_index = 0
process_tensor.shape = torch.Size([2, 200]), process_tensor.dtype = torch.bool, accelerator.process_index = 1

padded_tensor.shape = torch.Size([2, 300]), padded_tensor.dtype = torch.bool, accelerator.process_index = 2
padded_tensor.shape = torch.Size([2, 300]), padded_tensor.dtype = torch.int64, accelerator.process_index = 0
padded_tensor.shape = torch.Size([2, 300]), padded_tensor.dtype = torch.int64, accelerator.process_index = 1

The padded tensors have the incorrect dtype torch.int64, and the dtype now differs across devices. This mismatch can make downstream ops (e.g., gather) hang, which is hard to debug.
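A plausible root cause (an assumption on my part, not confirmed against the Accelerate source): if the padding helper fills a fresh buffer with something like `tensor.new_zeros(...) + pad_index`, where `pad_index` is a Python int, PyTorch's type-promotion rules lift a bool tensor to the default integer dtype. That would also explain why the process that is already at the maximum size keeps torch.bool: it can return its tensor unchanged without building a padded buffer. A minimal single-process sketch of the promotion:

```python
import torch

# Hedged sketch of the suspected promotion. Assumption: the padding helper
# builds the output as `tensor.new_zeros(new_size) + pad_index`.
mask = torch.zeros(2, 3, dtype=torch.bool)

# new_zeros preserves the source tensor's dtype...
buffer = mask.new_zeros(2, 5)
print(buffer.dtype)  # torch.bool

# ...but adding a Python int scalar promotes bool to the default int dtype.
padded = buffer + 0
print(padded.dtype)  # torch.int64
```

If this is indeed the mechanism, any tensor that needs padding gets promoted, while already-max-size tensors do not, producing exactly the cross-device dtype mismatch shown above.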

Expected behavior

The output tensor should have the same dtype on all devices, and that dtype should match the input dtype.
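Until a fix lands, one possible workaround (a sketch, not an official recommendation) is to cast the bool tensor to an integer dtype before padding and restore the dtype afterwards, e.g. `pad_across_processes(process_tensor.to(torch.uint8), dim=1).bool()`. The cast round-trip itself is lossless for boolean masks:

```python
import torch

# Workaround sketch: cast bool -> uint8 before padding, restore bool after.
# (The pad_across_processes call itself needs a distributed launch, so only
# the dtype round-trip is demonstrated here.)
mask = torch.tensor([[True, False, True], [False, True, False]])
roundtrip = mask.to(torch.uint8).bool()
print(torch.equal(roundtrip, mask))  # True
```

Note that the padding value 0 casts back to False, so padded positions behave like a False mask entry.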

@mariusarvinte
Contributor Author

I'll have a PR for this soon.

@mariusarvinte mariusarvinte linked a pull request Nov 5, 2024 that will close this issue

github-actions bot commented Dec 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
