Adam optimizer is slower after loading model from checkpoint #19955
Comments
Hey @radomirgr
These links might have pointed to an earlier version, but now they don't seem to show the place that you meant. Could you show me where in the PyTorch code this assumption is made? I don't remember exactly why we needed the `_optimizer_to_device` call in the first place.
Here are the screenshots: optimizer_to_device is needed because torch doesn't have a `.to(device)` method on optimizers, and you need to put the optimizer state on the GPU. There is an issue for that here: pytorch/pytorch#8741. It might be solved if you add …
PyTorch intentionally places the scalar Tensors on CPU for performance reasons, unless compile/capturable is needed. Executing Python math is faster and more precise than calling into a kernel, and here we want the calculations with step to be fast. Is there a reason Lightning moves everything to the GPU?
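(For concreteness, here is a small check of the default placement described above; this is my own illustration, not from the thread, and it assumes a CUDA device is available.)

```python
import torch

param = torch.nn.Parameter(torch.randn(8, device="cuda"))

# Default Adam: the scalar "step" state stays on CPU even though the
# parameter (and the exp_avg / exp_avg_sq buffers) live on the GPU.
opt = torch.optim.Adam([param])
param.grad = torch.randn_like(param)
opt.step()
print(opt.state[param]["step"].device)      # cpu

# With capturable=True (or fused=True), "step" is kept on the device instead.
opt_cap = torch.optim.Adam([param], capturable=True)
param.grad = torch.randn_like(param)
opt_cap.step()
print(opt_cap.state[param]["step"].device)  # cuda:0
```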
I can confirm this issue. What happens during a checkpoint is that the optimizer param state is stored (including its CPU or GPU location). But then, when Lightning reloads the params, it forces everything onto the GPU:
This causes a problem because the Adam optimizer explicitly expects `step` to be on the CPU @janeyx99:
When I run the above example code (to resume after a checkpoint) under NVIDIA Nsight, I can see that it forces many copies of `step` from the GPU to the CPU, where the algorithm expects it:
I see a total of 4094 copies from the device to the host. In contrast, if after a checkpoint restore we leave `step` on the CPU, we get only 74 copies:
This large number of transfers doesn't take a long time if you have a monopoly on the device. But if you are sharing a device, all these transfers can be a bottleneck (these copies force stream synchronization events). Tracing via tensorboard shows the underlying operation that is forcing this transfer. Basically, the PyTorch Lightning logic that blindly forces params onto the device is incorrect. Different algorithms may have different needs. Essentially, PyTorch Lightning is messing with the internal state of the optimizer and making incorrect assumptions.
One idea for a fix would be to add special handling based on the optimizer class, but it's a bit ugly.
With:
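(The snippet that originally followed is not shown here; for illustration, a rough sketch of what such optimizer-class-based special handling might look like. The helper name and the set of special-cased classes are assumptions.)

```python
import torch
from torch.optim import Adam, AdamW, Optimizer


def optimizer_state_to_device(optimizer: Optimizer, device: torch.device) -> None:
    # Move tensor state onto `device`, but keep the scalar "step" entries on
    # CPU for optimizer classes known to expect that placement.
    keep_step_on_cpu = isinstance(optimizer, (Adam, AdamW))
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if not torch.is_tensor(value):
                continue
            if keep_step_on_cpu and key == "step":
                continue  # leave "step" where PyTorch placed it
            param_state[key] = value.to(device)
```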
A better idea could be to push for optimizers to have a `to` method to map them onto a device. This has been discussed in torch before, along with how awkward it is to map optimizers to devices, but the request doesn't seem to have much traction. Maybe there is a way to copy-construct the optimizer and get the correct device assignments? But I don't see how to do it.
A third idea could be to look at the params dictionary and see whether each tensor was on the CPU or GPU, but I think that would get very flaky for remappings. E.g. you might start training on a CPU but then resume on a GPU.
As an aside, @radomirgr, another solution for the Adam optimizer might be to use the Adam parameter fused=True. Then it expects all the params to be on the GPU. In theory I think this idea could work, but when I tried it I still saw a bunch of forced copies from the GPU to CPU and I'm not sure why.
I'm coming into this naively, but it looks like an equivalent to the `_optimizer_to_device` step would simply be to load the checkpointed state into an optimizer built on parameters that already live on the target device.
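(The code snippet from this comment is not shown; judging from the replies below, it was presumably along these lines. The checkpoint layout here is my assumption.)

```python
import torch

# Toy "checkpoint" so the example is self-contained.
model = torch.nn.Linear(4, 4)
checkpoint = {
    "model": model.state_dict(),
    "optimizer": torch.optim.Adam(model.parameters()).state_dict(),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)                            # move the model first
optimizer = torch.optim.Adam(model.parameters())    # optimizer sees the final params
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])  # state is cast/moved to match params
```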
Is there a reason the above would not be viable? Tangentially, using fused=True would bypass this problem as it expects the step to be on CUDA, so @corwinjoy I am surprised that there are still forced copies from GPU to CPU. Are you on the latest torch nightly or an older version? Maybe these syncs have to do with the LRScheduler/lr.
@janeyx99 So, as I understand it, the reason for the `_optimizer_to_device` function is to make sure the restored optimizer state ends up on the same device as the model parameters, e.g. when a checkpoint is resumed on a different device than it was saved on.
In addition, I agree with you that fused=True should bypass this problem, but it doesn't in the version of torch I am using. Here I am using the most recent release from PyPI, torch==2.3.1. I'm not quite sure why the extra copies are happening, since tensorboard stack generation seems to be broken in the latest version of torch, so I don't have a good way to trace it. Anyway, that's why the function is there.
Yes, I understand the need to load on distinct devices, but my code snippet should still work for that. As long as one creates an optimizer referencing parameters that are already on the desired device (cuda:1, cuda, or even cpu), load_state_dict will place the state correctly. It feels like doing both a checkpoint load and then a move is redundant. For the fused=True still having copies -- once you get more details, please feel free to open an issue in pytorch/pytorch!
@janeyx99 Thanks! That's actually an interesting idea. I think my caveat here is that we cannot create the optimizer directly, since we (generically) have only the base Optimizer class (and the concrete class is loaded via pickle). But I think we could use your idea, something like this:
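(The snippet from this comment is not shown; a sketch of the same idea, with the helper name being my own. It assumes a single param group.)

```python
import torch
from torch.optim import Optimizer


def rebuild_optimizer_on_device(optimizer: Optimizer, model: torch.nn.Module) -> Optimizer:
    # We only know the concrete optimizer class at runtime (it came out of a
    # pickled checkpoint), so recover it via type() and rebuild it against the
    # parameters that are already on the target device.
    opt_cls = type(optimizer)
    new_optimizer = opt_cls(model.parameters(), **optimizer.defaults)
    # load_state_dict restores hyperparameters and per-parameter state,
    # casting/moving tensors to match the parameters passed in above.
    new_optimizer.load_state_dict(optimizer.state_dict())
    return new_optimizer
```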
What do you think? Unless you had some other way to do this? If so, I would like to see that code, since I don't understand how that would work generically.
OK. Doing further testing, unfortunately, the idea of using load_state_dict here turns out to be trickier than it looks.
Hm, maybe I am not understanding the use case correctly. I thought the goal was just to get the optimizer state onto the same device as the parameters it belongs to. Here is an explicit way to rewrite the function in terms of load_state_dict:
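(The code from this comment is not shown; an assumed reconstruction of the idea, relying on the fact that load_state_dict casts/moves state to match the device of each parameter.)

```python
from torch.optim import Optimizer


def optimizer_to_device(optimizer: Optimizer) -> None:
    # Round-trip the state through load_state_dict. The parameters referenced
    # by the optimizer have already been moved (nn.Module.to() moves them in
    # place), and load_state_dict casts/moves each state tensor to match its
    # parameter -- while leaving entries like Adam's "step" on CPU.
    optimizer.load_state_dict(optimizer.state_dict())
```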
So the above should work with any optimizer generically, but it is very roundabout, because it is confusing to me why there is an optimizer input with mismatching state in the first place. Instead, what I would expect in a use case is for the optimizer to be correctly loaded during checkpointing through load_state_dict.
@janeyx99 I'm still a bit new to all this, but here is what I see in the stack trace when debugging a restore from checkpoint (as per the above code). You have to look at the second call to `_optimizer_to_device`.
Looking at the tensors from the checkpoint, they do have the right locations before `_optimizer_to_device` is called. Also, knowing that …
Hey everyone …
Ah, thanks @awaelchli and @corwinjoy for the context. I see the original problem this function sought to solve was that the model parameters shifted under a created optimizer, causing the mismatch in devices between parameters and optimizer state.
Here, the solution should not be to move the optimizer, but to wait until the model has been moved to its final location and then to create the optimizer. If that's not possible, reloading the state dict into a new optimizer with the final parameters would also work. I would suggest the cleaner solution of maintaining the invariant that the optimizer should be created after the model is done being modified, to ensure that the latest parameters are what get optimized. Without this invariant, it's easy to get into a wild goose chase of problems like this that crop up due to the mismatch.
@corwinjoy The reason it works is that load_state_dict will move state to match the parameter that is passed into the optimizer -- there is already code in there to cast/move state appropriately for each optimizer, so the work should not need to be duplicated. Feel free to follow up if you have more questions -- I am increasingly convinced that the spot-solution of patching the function for this issue is at best only a temporary one.
In order to move the discussion forward, I have created a PR where this function is simply disabled to see what tests fail. It is at #20036. Before we move forward, I believe we should agree on what the behavior here should be. I think the existing test for the function doesn't really pin down that behavior, @janeyx99.
But I would be happy to be proven wrong. If there is a clean way to use load_state_dict, I would like to see it; I think it will be tricky, though. Maybe use the model as a prototype?
I don't understand point 2 -- why would using the latest state not work correctly / be different from the current implementation? When one reloads from a checkpoint, one has to start with a fresh optimizer instance, no? I still think the strongest, sturdiest solution to push for is to ensure that the optimizer state is generated with the latest parameters.
Because the optimizer doesn't just hold a single set of parameters. Instead, it holds an array of per-parameter entries covering every model parameter it has seen. So the optimizer restore carries an array of parameter state that needs to be converted.
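(To make the per-parameter state concrete, a small illustration of my own, not from the thread:)

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
opt = torch.optim.Adam(model.parameters())
model(torch.randn(3, 4)).sum().backward()
opt.step()

sd = opt.state_dict()
print(list(sd.keys()))              # ['state', 'param_groups']
print(len(sd["state"]))             # 4 -- one entry per parameter (2 weights + 2 biases)
print(list(sd["state"][0].keys()))  # ['step', 'exp_avg', 'exp_avg_sq']
```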
Is there any reason the optimizer needs to hold these parameters for some time before moving them? Regardless, the parameters are only referenced by the optimizer; the optimizer object itself only holds Tensors for the state. In this use case, it looks like the old state on the checkpointed device is never used, so it should really never be created. Instead, my understanding is that state should only be created for the parameters that matter, which would be the latest set of parameters. Regardless, I think your approach of removing all accesses to the _optimizer_to_device function is a good place to start -- then we can talk about any actual problems we didn't expect in the code.
OK. To answer these questions I have submitted the following PR with an improved test for `_optimizer_to_device`.
I'm hoping this code makes it clearer what is going on and where the issue lies.
Also see the docs for …
We should have them; it's probably just hard to find them, haha. But we can also add such a test. So you'd suggest not removing the function? That's ok with me if PyTorch's …
I commented on the PR above as well, but here's how we test this use case in PyTorch, in case it helps: https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L1545-L1574
Here is how I imagine checkpointing should go:
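(The list that followed is not shown; based on the linked test and the earlier comments, the flow was presumably along these lines. The function names below are my own.)

```python
import torch


def save_checkpoint(model, optimizer, path="ckpt.pt"):
    # Save everything as state dicts; no device handling is needed here.
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)


def resume(path="ckpt.pt", device="cuda"):
    model = torch.nn.Linear(4, 4).to(device)          # 1. model on its final device
    optimizer = torch.optim.Adam(model.parameters())  # 2. optimizer created afterwards
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])              # 3. restore weights
    optimizer.load_state_dict(ckpt["optimizer"])      # 4. state moved/cast to match params
    return model, optimizer
```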
Could you link me to the specific part that is broken? Is it a test failure or something in the code? By the way, we have fused Adam, AdamW, SGD, and Adagrad on CPU now! So that could be related.
Ah yes, thank you for clarifying that by history you mean the param state! I agree! This is precisely what load_state_dict should handle.
@corwinjoy @awaelchli I still believe load_state_dict should be sufficient for this use case. I've tried addressing the concerns above -- please point me to where specialization is needed beyond the use case I delineated above. Thank you both for the detailed discussion.
@awaelchli @janeyx99 OK. I have done further investigation and added additional comments and tests to #20062.
But, just eliminating this function does create a couple potential problems that I am not sure I understand and would like a review on.
Anyway, I guess we could replace this function with a no-op, with some comments explaining the behavior, because I think it is rather non-obvious. Here is what I have in the PR right now for `_optimizer_to_device`:
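(The code from the PR is not reproduced here; a sketch of the step-specific special case being discussed, not necessarily the exact code that was merged.)

```python
import torch


def _optimizer_to_device(optimizer, device):
    # Move tensor state to the target device, but leave the "step" entries
    # wherever PyTorch placed them (CPU by default), so the optimizer keeps
    # its fast, sync-free code path.
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if key != "step" and isinstance(value, torch.Tensor):
                param_state[key] = value.to(device)
```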
Eventually, maybe we can rip out `_optimizer_to_device` entirely.
There was a desire in the past to have the trainer leave no memory behind and shut down cleanly, so that other workflows after fit() could use all the memory. I think we can still keep that separate from what you are fixing. I'm ok with the plan of first adding the step-specific fix, and then I'd also prefer a removal of this function long-term. A big no from me about leaving the function in place as a permanent no-op, though.
I merged the fix and opened #20165 so we can work on removing the function in the future.
@awaelchli - thanks so much for the improved and very nice tests! I think this helps clarify the behavior we want. Also, thanks for merging the interim fix so we can see improved performance as we work toward removing the function. Also FYI @radomirgr.
Bug description
When I was resuming training from a checkpoint I noticed slowness in GPU utilization. I found that Adam is doing a CUDA sync after restoring from a checkpoint. This is a problem if you have a lot of optimizers in your network.
The Adam implementation assumes that the step component of the optimizer state is a CPU tensor. It is assumed here, which is executed in Adam here.
The problem is that Lightning puts all of the optimizer state on the GPU here.
My current workaround is:
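(The original snippet is not shown here; a sketch of one possible workaround in this spirit. The callback below is my own assumption, not the issue's code: it pushes Adam's "step" entries back to CPU after the checkpoint has been restored.)

```python
import torch
import lightning.pytorch as pl


class StepBackToCpu(pl.Callback):
    """Sketch of a workaround: after the trainer restores the optimizer state,
    move the scalar "step" entries back to CPU so Adam avoids the extra
    device-to-host copies and syncs."""

    def on_train_start(self, trainer, pl_module):
        for optimizer in trainer.optimizers:
            for state in optimizer.state.values():
                step = state.get("step")
                if isinstance(step, torch.Tensor) and step.device.type != "cpu":
                    state["step"] = step.cpu()
```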
What version are you seeing the problem on?
v2.2
How to reproduce the bug
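(The reproduction script from the issue is not preserved; a minimal sketch of the shape of a repro, using the BoringModel demo with Adam, is below. Profiling the resumed run, e.g. with nsys, shows the extra device-to-host copies of Adam's "step".)

```python
import torch
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel


class AdamBoringModel(BoringModel):
    def configure_optimizers(self):
        # Use Adam so the "step" device placement matters.
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# First run: train briefly and save a checkpoint.
trainer = pl.Trainer(max_steps=50, accelerator="gpu", devices=1)
trainer.fit(AdamBoringModel())
trainer.save_checkpoint("resume.ckpt")

# Second run: resume from the checkpoint and profile it.
trainer = pl.Trainer(max_steps=100, accelerator="gpu", devices=1)
trainer.fit(AdamBoringModel(), ckpt_path="resume.ckpt")
```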
Error messages and logs
Below are some nsys traces.
Environment
Current environment
More info
No response
cc @Borda