Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature #34883
Conversation
Hi @daniel-bogdoll, thanks for adding this! It looks great to me. Do you think it might be worth extending the same option to BatchFeature to ensure consistent capabilities?
Thanks @qubvel, sure thing! Which tests would I need to run to make sure modifications to the to() function of BatchFeature get tested? Just to make sure, I assume you are referring to BatchFeature.to()?
Yes, I'm referring to that one, but I'm not sure it's properly tested anywhere; I was able to find only one reference.
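As a rough illustration, a minimal test along these lines could cover the new kwarg (the test name and placement are hypothetical, not taken from the repository's test suite):

```python
# Hypothetical test sketch, not from the repository: checks that to()
# accepts non_blocking and that moving to CPU leaves the data intact.
import torch

from transformers import BatchEncoding


def test_batch_encoding_to_accepts_non_blocking():
    encoding = BatchEncoding({"input_ids": torch.tensor([[1, 2, 3]])})
    moved = encoding.to("cpu", non_blocking=True)  # no-op copy on CPU
    assert torch.equal(moved["input_ids"], torch.tensor([[1, 2, 3]]))
```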
Maybe we can do it as simply as:

```python
non_blocking = kwargs.get("non_blocking", False)
...
elif isinstance(v, torch.Tensor) and device is not None:
    new_data[k] = v.to(device=device, non_blocking=non_blocking)
...
```
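For illustration, here is that snippet expanded into a self-contained toy container so the whole flow is visible; SimpleBatchFeature is a hypothetical stand-in, not the upstream BatchFeature implementation:

```python
# Illustrative sketch only: SimpleBatchFeature is a toy stand-in for
# transformers.BatchFeature, showing where non_blocking would be read.
import torch


class SimpleBatchFeature(dict):
    def to(self, *args, **kwargs) -> "SimpleBatchFeature":
        non_blocking = kwargs.get("non_blocking", False)
        device = kwargs.get("device")
        # Mirror the pattern discussed above: the device may also be
        # passed positionally, e.g. batch.to("cuda").
        if device is None and len(args) > 0:
            device = args[0]
        new_data = {}
        for k, v in self.items():
            if isinstance(v, torch.Tensor) and device is not None:
                new_data[k] = v.to(device=device, non_blocking=non_blocking)
            else:
                new_data[k] = v
        self.update(new_data)
        return self


batch = SimpleBatchFeature({"pixel_values": torch.randn(1, 3, 224, 224)})
target = "cuda" if torch.cuda.is_available() else "cpu"
batch = batch.to(target, non_blocking=True)  # async copy on CUDA, no-op on CPU
```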
That's how I would have tried it as well. But what about this block?
Here device is derived from the positional args.
I don't think so. Maybe at some point it is worth refactoring this method to use more explicit args and kwargs. For now, we can add a note in the docstring that non_blocking can be passed as a keyword argument.
@qubvel Done! Thanks for the super-fast replies, it was a pleasure! Tests fail now, though. For the first one, as you stated here (#34826 (comment)), it does not seem to be related. The second one is a timeout issue, so it also seems unrelated.
Option to set 'non_blocking' for the to(device) operation in BatchEncoding for performance improvements. Defaults to 'False', thus no behavioral changes.
What does this PR do?
This minor PR adds the non_blocking option to the to() function.
```python
# Previous
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":

# New
def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding":
```
Since `non_blocking` defaults to `False`, this PR does not introduce behavioral changes.
While using zero-shot object detection models, I realized that it was not possible to set this option, which led to sub-optimal performance during inference.
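As a usage sketch (assuming this PR's change; the checkpoint name is illustrative): passing `non_blocking=True` issues an asynchronous host-to-device copy, which only truly overlaps with compute when the source CPU tensors are in pinned memory, and is otherwise a safe no-op:

```python
# Usage sketch assuming this PR's change; the checkpoint is illustrative.
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
image = Image.new("RGB", (640, 480))  # placeholder image
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")

if torch.cuda.is_available():
    # Asynchronous with respect to the host; fully overlaps with compute
    # only when the source CPU tensors live in pinned (page-locked) memory.
    inputs = inputs.to("cuda", non_blocking=True)
```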
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?