Reimplement dynamic_propagation with meta tensors #204

Closed
jansel opened this issue May 5, 2022 · 9 comments

jansel commented May 5, 2022

One obvious source of memory overhead from TorchDynamo is config.dynamic_propagation=True. With this mode, TorchDynamo will create an example_value copy of every tensor and nn.Module in order to have accurate Python type/dtype/device/shape information. This could easily double memory usage in the worst case.

This approach is nice in that it is highly accurate and trivial to implement -- however, it is very wasteful in terms of memory.

We should rewrite dynamic_propagation to use meta tensors (and fall back to real tensors for ops where meta tensors aren't implemented).

It is very possible there are other sources of memory overhead as well; I think @anijain2305 is looking into one.

Most things should work if you disable dynamic_propagation. The exceptions are that it enables constant inlining of tensor properties (dtype/device/ndim/shape/contiguous/layout/etc.) and handling of ops that return lists/tuples/etc.

_Originally posted by @jansel in pytorch/pytorch#93751_
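
For illustration, here is a minimal sketch of the proposed direction (this is not TorchDynamo's actual code, and the helper name propagate_with_meta is made up): run the op on meta copies of the inputs so that shape/dtype propagate without allocating real storage.

```python
import torch

def propagate_with_meta(op, *real_args):
    # Re-create each tensor argument on the "meta" device: same shape/dtype,
    # but no backing storage, so the memory overhead is negligible.
    meta_args = [
        a.to("meta") if isinstance(a, torch.Tensor) else a
        for a in real_args
    ]
    out = op(*meta_args)
    # The result carries accurate shape/dtype/ndim, but its device is "meta",
    # which is why device information has to be tracked separately.
    return out.shape, out.dtype, out.device

x = torch.randn(1024, 1024)
w = torch.randn(1024, 1024)
print(propagate_with_meta(torch.matmul, x, w))
# (torch.Size([1024, 1024]), torch.float32, device(type='meta'))
```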

jansel commented May 5, 2022

The tricky parts of using meta tensors will be:

  1. working around the ops where meta implementations don't exist. Perhaps we could just materialize random tensors in these cases and free them after running the example value (see the sketch after this list).
  2. we won't have device information when using meta tensors (since "meta" is itself a device). I think it should be easy to emulate this by just assuming devices stay the same and building a list of ops that break that rule.
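
A hedged sketch of the fallback idea from point 1 above (the helper name example_value and the exception types caught are assumptions, not TorchDynamo's actual code):

```python
import torch

def example_value(op, *meta_args):
    try:
        # Cheap path: a meta kernel exists, so no real memory is allocated.
        return op(*meta_args)
    except (NotImplementedError, RuntimeError):
        # No meta kernel: materialize throwaway random tensors with the same
        # shapes/dtypes, run the op for real, and keep only the metadata.
        # (torch.rand is float-only; a real implementation would also handle
        # integer/bool dtypes and the original devices.)
        real_args = [
            torch.rand(a.shape, dtype=a.dtype) if isinstance(a, torch.Tensor) else a
            for a in meta_args
        ]
        out = op(*real_args)
        # The real temporaries go out of scope (and are freed) after this.
        return torch.empty(out.shape, dtype=out.dtype, device="meta")
```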

jansel commented May 5, 2022

@eellison any interest in this one?

Chillee commented May 5, 2022

There's some work on various parts of this (might be a FB-only link): https://docs.google.com/document/d/1W1eWV5F4UEEkeVOIRUNwt68Pb_1kg_mwodBMB4Pzjac/edit?usp=sharing

> we won't have device information when using meta tensors

This is pretty annoying haha - we can resolve this with something called "fake tensors".

Perhaps we could make a tensor subclass for this - I used to have a meta tensor tracing mode for __torch_dispatch__ tracing, and I did both of the things you mentioned haha. You can see the remnants deleted in this PR (https://github.com/pytorch/functorch/pull/554/files).
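
Roughly, the subclass idea could look like the sketch below: a wrapper that dispatches to meta kernels but carries the "real" device alongside. This is illustrative only -- it is not the code from that PR or the eventual fake tensor implementation in core, and the class name and the naive same-device rule are assumptions.

```python
import torch
from torch.utils._pytree import tree_map

class DeviceTrackingMetaTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, meta_elem, real_device):
        # The wrapper shares the meta tensor's metadata (shape/strides/dtype).
        r = torch.Tensor._make_wrapper_subclass(
            cls, meta_elem.size(), strides=meta_elem.stride(),
            dtype=meta_elem.dtype, device=meta_elem.device,
        )
        r._elem = meta_elem
        r._real_device = real_device
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t._elem if isinstance(t, cls) else t

        # Naive rule: assume outputs live on the same device as the first
        # wrapped input; conversion ops would need to be special-cased.
        real_device = next(
            (a._real_device for a in args if isinstance(a, cls)),
            torch.device("cpu"),
        )
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

        def wrap(t):
            return cls(t, real_device) if isinstance(t, torch.Tensor) else t

        return tree_map(wrap, out)
```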

I think the right solution should probably leverage decompositions significantly.

cc: @ezyang

ezyang commented May 5, 2022

We are going to do fake tensors.

ezyang commented May 5, 2022

For generating random inputs, this will interact poorly with indexing ops and data-dependent ops (like nonzero), so we just need full coverage there. The plan of record is also to make it possible to add meta coverage in Python, so it's easy to drop in a prop function as necessary.
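
To illustrate the data-dependence issue: ordinary ops can propagate shapes on the meta device, but nonzero's output shape depends on the tensor's values, so it cannot be answered from metadata alone and (as of this discussion) errors out rather than propagating.

```python
import torch

a = torch.empty(8, 8, device="meta")
print((a @ a).shape)  # torch.Size([8, 8]) -- shape is data-independent

try:
    torch.nonzero(a)  # output shape depends on the tensor's actual values
except Exception as e:
    print("no meta propagation for data-dependent op:", type(e).__name__)
```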

jansel commented May 5, 2022

Awesome, fake tensors sounds perfect!

eellison commented May 5, 2022

Happy to take this on, but there's actually an annoying number of operators that might not trivially preserve the input's device, e.g. any conversions, plus all of the annoying behavior w.r.t. 0-dim tensors. @ezyang do we have any timeline on fake tensors in core? Depending on the timeline, it might make more sense to wait for their availability.

Edit: actually, just special-casing the conversion operators, and re-materializing and re-running in the case where there are inputs on different devices, should be fine.
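
A sketch of that special-casing, with made-up names and an intentionally naive rule (not the actual implementation):

```python
import torch

# Illustrative, not exhaustive: ops whose output device comes from the call itself.
DEVICE_CONVERSION_OPS = {"aten::to", "aten::_to_copy"}

def infer_output_device(op_name, input_devices, kwargs):
    """Guess the output device of an op that was run on meta tensors."""
    if op_name in DEVICE_CONVERSION_OPS and "device" in kwargs:
        dev = kwargs["device"]
        return dev if isinstance(dev, torch.device) else torch.device(dev)
    if len(set(input_devices)) > 1:
        # Mixed-device inputs (e.g. 0-dim CPU tensors mixed with CUDA tensors):
        # fall back to re-materializing real tensors and running the op.
        raise NotImplementedError("re-materialize and run on real tensors")
    # Common case: the output stays on the same device as its inputs.
    return input_devices[0]
```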

ezyang commented May 8, 2022

eellison signed up to make fake tensors in core happen

chekangliang added the enhancement (New feature or request) label on May 16, 2022
ezyang commented Jul 20, 2022

The basic functionality works; we just need to fix all the bugs.
