refactor: modifying _update_from_checkpoint signature #186

Open · wants to merge 1 commit into main
18 changes: 8 additions & 10 deletions substrafl/algorithms/pytorch/torch_base_algo.py
@@ -244,14 +244,14 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
device = torch.device("cuda")
return device

def _update_from_checkpoint(self, path: Path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Load the checkpoint and update the internal state
from it.
Pop the values from the checkpoint so that we can ensure that it is empty at the
end, i.e. all the values have been used.

Args:
path (pathlib.Path): path where the checkpoint is saved
checkpoint (dict): the checkpoint to load.

Returns:
dict: checkpoint
@@ -260,13 +260,11 @@ def _update_from_checkpoint(self, path: Path) -> dict:

.. code-block:: python

def _update_from_checkpoint(self, path: Path) -> dict:
checkpoint = super()._update_from_checkpoint(path=path)
def _update_from_checkpoint(self, checkpoint: dict) -> None:
super()._update_from_checkpoint(checkpoint=checkpoint)
self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")
return checkpoint
return
"""
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
checkpoint = torch.load(path, map_location=self._device)
self._model.load_state_dict(checkpoint.pop("model_state_dict"))

if self._optimizer is not None:
@@ -285,8 +283,6 @@ def _update_from_checkpoint(self, path: Path) -> dict:
else:
torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))

return checkpoint

def load_local_state(self, path: Path) -> "TorchAlgo":
"""Load the stateful arguments of this class.
Child classes do not need to override that function.
@@ -297,7 +293,9 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
Returns:
TorchAlgo: The class with the loaded elements.
"""
checkpoint = self._update_from_checkpoint(path=path)
assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
checkpoint = torch.load(path, map_location=self._device)
self._update_from_checkpoint(checkpoint=checkpoint)
assert len(checkpoint) == 0, f"Not all values from the checkpoint have been used: {checkpoint.keys()}"
return self

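To make the split introduced in this file concrete, here is a minimal, self-contained sketch (stand-in classes and dummy values, not substrafl code) of the contract the diff moves to: `load_local_state` owns the file I/O, while `_update_from_checkpoint` only consumes an in-memory dict, each level popping the keys it owns so the caller can check that nothing is left over.

```python
# Illustrative sketch only: stand-in classes, not the real TorchAlgo hierarchy.
class _BaseAlgoSketch:
    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        # Pop every key this level owns; leave the rest for subclasses.
        self._model_state = checkpoint.pop("model_state_dict")


class _ChildAlgoSketch(_BaseAlgoSketch):
    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        super()._update_from_checkpoint(checkpoint=checkpoint)
        self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")


# The caller (load_local_state in the real code) loads the dict from disk,
# hands it over, then verifies that every value has been consumed.
algo = _ChildAlgoSketch()
checkpoint = {"model_state_dict": {"weight": 1.0}, "strategy_specific_variable": 42}
algo._update_from_checkpoint(checkpoint=checkpoint)
assert len(checkpoint) == 0, f"Not all values from the checkpoint have been used: {checkpoint.keys()}"
```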
8 changes: 4 additions & 4 deletions substrafl/algorithms/pytorch/torch_fed_pca_algo.py
@@ -336,7 +336,7 @@ def _get_state_to_save(self) -> dict:
)
return checkpoint

def _update_from_checkpoint(self, path: Path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Load the checkpoint and update the internal state from it.

Pop the values from the checkpoint so that we can ensure that it is empty at the
@@ -345,12 +345,12 @@ def _update_from_checkpoint(self, path: Path) -> dict:
to be popped.

Args:
path (pathlib.Path): path where the checkpoint is saved
checkpoint (dict): the checkpoint to load.

Returns:
dict: checkpoint
"""
checkpoint = super()._update_from_checkpoint(path)
super()._update_from_checkpoint(checkpoint=checkpoint)
self.local_mean = checkpoint.pop("mean")
self.local_covmat = checkpoint.pop("covariance_matrix")
return checkpoint
return
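As a hedged illustration of how the relocated file I/O pairs with this override (the key names come from the diff above; the values and the temporary file are dummies, not a substrafl test), the dict built by `_get_state_to_save` can be written and re-read by the caller before being handed to the new `_update_from_checkpoint`:

```python
# Sketch of the caller-side round trip; "mean" and "covariance_matrix" are the
# keys popped by the fed-PCA override, "model_state_dict" by the base class.
import tempfile
from pathlib import Path

import torch

checkpoint = {
    "model_state_dict": {"eigen_parameters": torch.zeros(2, 2)},
    "mean": torch.zeros(2),
    "covariance_matrix": torch.eye(2),
}

with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = Path(tmp) / "model.pt"
    torch.save(checkpoint, checkpoint_path)
    # This is the step load_local_state now performs before delegating:
    loaded = torch.load(checkpoint_path, map_location=torch.device("cpu"))

# Each _update_from_checkpoint override then pops its own keys from `loaded`.
mean = loaded.pop("mean")
covariance_matrix = loaded.pop("covariance_matrix")
```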
9 changes: 4 additions & 5 deletions substrafl/algorithms/pytorch/torch_scaffold_algo.py
@@ -1,6 +1,5 @@
import logging
from enum import IntEnum
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
@@ -493,18 +492,18 @@ def _get_state_to_save(self) -> dict:
)
return local_state

def _update_from_checkpoint(self, path: Path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Load the local state from the checkpoint.

Args:
path (pathlib.Path): path where the checkpoint is saved
checkpoint (dict): the checkpoint to load.

Returns:
dict: checkpoint
"""
checkpoint = super()._update_from_checkpoint(path=path)
super()._update_from_checkpoint(checkpoint=checkpoint)
self._client_control_variate = checkpoint.pop("client_control_variate")
return checkpoint
return

def summary(self):
"""Summary of the class to be exposed in the experiment summary file
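For downstream code that overrides `_update_from_checkpoint`, this signature change is breaking. A possible migration sketch (the subclass, the `my_variable` key and the choice of `TorchFedAvgAlgo` as base class are illustrative assumptions, not part of this PR):

```python
from substrafl.algorithms.pytorch import TorchFedAvgAlgo  # example base class


class MyAlgo(TorchFedAvgAlgo):
    # New-style override: receives the already-loaded dict and returns None.
    # Previously the override took a pathlib.Path, called
    # `checkpoint = super()._update_from_checkpoint(path=path)` and had to
    # `return checkpoint` itself.
    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        super()._update_from_checkpoint(checkpoint=checkpoint)
        self._my_variable = checkpoint.pop("my_variable")  # illustrative key

    def _get_state_to_save(self) -> dict:
        # Keep save and load symmetric so the emptiness check in
        # load_local_state still passes.
        checkpoint = super()._get_state_to_save()
        checkpoint["my_variable"] = self._my_variable
        return checkpoint
```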