Adding checkpoint_path for resume training #182
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main     #182      +/- ##
==========================================
- Coverage   37.95%   37.75%   -0.20%
==========================================
  Files          20       20
  Lines        1333     1348      +15
==========================================
+ Hits          506      509       +3
- Misses        827      839      +12
```

☔ View full report in Codecov by Sentry.
Signed-off-by: nikk-nikaznan <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>
…Centre/crabs-exploration into nikkna/resume_training
thanks Nik 🙌
While reviewing this PR I found a super tricky bug - but I think I found a fix too.
the bug
Steps to reproduce:
- Train a model for one epoch (specifying `n_epochs=1` in the yaml file) and save a `weights_only` checkpoint.
  - The checkpoint is at the `path_to_checkpoints` parameter logged in MLflow (the name is `last.ckpt`).
- Then launch a training job that starts from that checkpoint. Before I launch it, I edit the config file to have `n_epochs=3`.
- In MLflow, this second training job has the same hyperparameters as the job that produced the checkpoint (so it has `n_epochs=1` etc), but in reality the job runs for as many epochs as in the yaml file. So it logs `n_epochs=1`, but runs for `n_epochs=3`.
This is a problem because the logged hyperparameters and the actual hyperparameters used don't match.
This does not happen if you restart training from a "full" checkpoint. In that case the .yaml parameters are logged and used.
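To make the reproduction concrete, here is a minimal sketch of the two training jobs, assuming a standard PyTorch Lightning workflow. The `FasterRCNN` class is the repo's, but the import path, config keys, file names and the omitted dataloaders are illustrative assumptions, not the actual CLI entry points.

```python
import yaml
import lightning.pytorch as pl  # Lightning >= 2.0 import style assumed
from lightning.pytorch.callbacks import ModelCheckpoint

from crabs.detection_tracking.models import FasterRCNN  # import path assumed

# Job 1: train for one epoch and save a weights-only checkpoint.
with open("config.yaml") as f:   # yaml contains n_epochs: 1
    config = yaml.safe_load(f)

ckpt_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    save_last=True,            # writes last.ckpt
    save_weights_only=True,    # stores weights + hyperparameters, no optimizer state
)
model = FasterRCNN(config=config)
trainer = pl.Trainer(max_epochs=config["n_epochs"], callbacks=[ckpt_callback])
trainer.fit(model)  # dataloaders omitted for brevity

# Job 2: edit config.yaml to n_epochs: 3, then restart from last.ckpt.
# Without the fix below, MLflow shows the hyperparameters stored in the
# checkpoint (n_epochs=1) even though training actually runs for 3 epochs.
model = FasterRCNN.load_from_checkpoint("checkpoints/last.ckpt")
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model)  # runs for 3 epochs, but logs n_epochs=1
```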
the fix
I am still not sure why this is happening - I opened an issue to further investigate. But I found we can overwrite the hyperparameters that are logged in the weights-only checkpoint with the yaml ones, by passing the config to `load_from_checkpoint`:
```python
lightning_model = FasterRCNN.load_from_checkpoint(
    self.checkpoint_path,
    config=self.config,
)
```
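As far as I understand Lightning's behaviour, extra keyword arguments passed to `load_from_checkpoint` override the values stored under `hyper_parameters` in the checkpoint, so passing `config=self.config` makes the model (and whatever ends up logged to MLflow) use the yaml parameters instead of the ones saved at checkpoint time.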
Question
If we make this change, whenever we restart training from a ckpt ("full" or "weights-only"), the hyperparameters used will be the .yaml config ones, and not the ckpt ones.
The question is: would we want the opposite behaviour? (Which we can't have right now, because I don't know how to properly fix this bug.) That would be using the hyperparams from the ckpt when loading a ckpt. It seems convenient, but it sounds like something one would forget about. Let me know what you think.
I guess this also highlights that we need more detailed tests 😢
…Centre/crabs-exploration into nikkna/resume_training
Good job spotting the bug, and thank you for the review. I don't see anything wrong with the new implementation you proposed; it definitely makes sense. The issue only happens with the `weights_only` checkpoint. For the full one, it will always be based on the checkpoint, as it is resuming training; that's why we need the full state. And if you train for 10 epochs and don't change the yaml to make it more than 10, no training will happen, as it has already reached `max_epochs`. Hope this makes sense, happy to chat offline.
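To illustrate the `max_epochs` point, here is a minimal sketch, assuming the standard Lightning resume API; the import path, file names, values and omitted dataloaders are illustrative.

```python
import yaml
import lightning.pytorch as pl  # Lightning >= 2.0 import style assumed

from crabs.detection_tracking.models import FasterRCNN  # import path assumed

with open("config.yaml") as f:
    config = yaml.safe_load(f)  # e.g. n_epochs: 10

# Resuming from a "full" checkpoint: ckpt_path restores weights, optimizer
# state and the epoch counter, so training continues where it stopped.
model = FasterRCNN(config=config)
trainer = pl.Trainer(max_epochs=config["n_epochs"])

# If last.ckpt was already saved at epoch 10 and max_epochs is still 10,
# fit() returns almost immediately: the restored epoch counter already meets
# the stopping condition, so no further training happens.
trainer.fit(model, ckpt_path="checkpoints/last.ckpt")  # dataloaders omitted
```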
thanks for clarifying Nik! Some questions:
We can chat later today if it's easier, yes!
thanks Nik! 🚀
Just one question and small suggestions
```diff
@@ -77,16 +80,16 @@ def setup_trainer(self):
         )

         # Define checkpointing callback for trainer
-        config = self.config.get("checkpoint_saving")
-        if config:
+        config_ckpt = self.config.get("checkpoint_saving")
```
I like a different name 😁 ✨
Co-authored-by: sfmig <[email protected]> Signed-off-by: nikk-nikaznan <[email protected]>
Co-authored-by: sfmig <[email protected]> Signed-off-by: nikk-nikaznan <[email protected]>
Co-authored-by: sfmig <[email protected]> Signed-off-by: nikk-nikaznan <[email protected]>
* adding ckpt_path to fit to resume training
* option to resume or fine tunning
* small changes
* add checkpoint option in the guide
* cleaned up guide
* cleaned up
* tring rename the ckpt
* cleaned up after rebase
* some changes in the guide
* run pre-commit
* fixed test
* parsing the config to model instance during fine-tunning
* small changes on guide
* changes based on the review
* small changes
* Update crabs/detection_tracking/train_model.py
  Co-authored-by: sfmig <[email protected]>
  Signed-off-by: nikk-nikaznan <[email protected]>
* Update crabs/detection_tracking/train_model.py
  Co-authored-by: sfmig <[email protected]>
  Signed-off-by: nikk-nikaznan <[email protected]>
* Update crabs/detection_tracking/train_model.py
  Co-authored-by: sfmig <[email protected]>
  Signed-off-by: nikk-nikaznan <[email protected]>
* cleaned up pre-commit

---------

Signed-off-by: nikk-nikaznan <[email protected]>
Co-authored-by: sfmig <[email protected]>
closes #179