Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature #34883

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

daniel-bogdoll
Copy link

Option to set 'non_blocking' for to(device) operation in BatchEncoding for performance improvements. Defaults to 'false', thus no behavioral changes.

What does this PR do?

This minor PR adds the non_blocking option to the to() function.

Previous: def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
New: def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding":

Since non_blocking defaults to 'False', this PR does not introduce behavioral changes.

I realized, when utilizing Zero Shot Object Detection models, that it was not possible to set this option, leading to sub-optimal performance during inference.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

… improvements. Defaults to 'false', thus no behavioral changes.
Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @daniel-bogdoll, thanks for adding this! It looks great to me. Do you think it might be worth extending the same option to BatchFeature to ensure consistent capabilities?

@daniel-bogdoll
Copy link
Author

daniel-bogdoll commented Nov 22, 2024

Thanks @qubvel, sure thing! Which tests would I need to run to make sure modifications in the to() function of BatchFeature get tested?

Just to make sure, I assume you refer to

def to(self, *args, **kwargs) -> "BatchFeature":
?

@qubvel
Copy link
Member

qubvel commented Nov 22, 2024

Yes, I refer to this one, but not sure it's properly tested anywhere, I was able to find only SequenceFeatureExtractionTestMixin

@qubvel
Copy link
Member

qubvel commented Nov 22, 2024

Maybe we can do it as simple as

non_blocking = kwargs.get("non_blocking", False)
...
elif isinstance(v, torch.Tensor) and device is not None:
      new_data[k] = v.to(device=device, non_blocking=non_blocking)
...

@daniel-bogdoll
Copy link
Author

That's how I would have tried it as well. But what about this block?

# Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should be always the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

Here device is derived from args rather than kwargs. Should this be extended in some way to also consider deriving non_blocking? Not sure where or how this is used.

@daniel-bogdoll daniel-bogdoll changed the title Option to set 'non_blocking' for to(device) operation in BatchEncoding Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature Nov 22, 2024
@qubvel
Copy link
Member

qubvel commented Nov 22, 2024

Here device is derived from args rather than kwargs. Should this be extended in some way to also consider deriving non_blocking? Not sure where or how this is used.

I don't think so, maybe at some moment, it is worth refactoring this method for more explicit args and kwargs. For now, we can add a note in docstring that non_blocking should be passed as a keyword argument.

@daniel-bogdoll
Copy link
Author

daniel-bogdoll commented Nov 22, 2024

@qubvel Done! Thanks for the super-fast replies, was a pleasure! Tests fail now, though:

For the first one, as you stated here (#34826 (comment)), it does not seem to be related.

https://app.circleci.com/pipelines/github/huggingface/transformers/111324/workflows/3351b194-4b9e-4a17-876b-85360fc7ff01/jobs/1482124?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-checks-link&utm_content=summary

FAILED
tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py::XLMRobertaXLModelTest::test_assisted_decoding_matches_greedy_search_1_same 
- AssertionError: False is not true

As the second one is a timeout issue, it also seems unrelated:

https://app.circleci.com/pipelines/github/huggingface/transformers/111324/workflows/3351b194-4b9e-4a17-876b-85360fc7ff01/jobs/1482127?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-checks-link&utm_content=summary

FAILED
tests/models/convbert/test_modeling_convbert.py::ConvBertModelTest::test_pipeline_fill_mask -
requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443):
Read timed out. (read timeout=10)"), '(Request ID: 04e3d1b8-11fc-4791-ba74-3d7d67a5f3f2)')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants