
refactor: modifying _update_from_checkpoint signature
Signed-off-by: jeandut <[email protected]>
jeandut committed Jan 19, 2024
1 parent 325a43f commit a56ce18
Showing 2 changed files with 11 additions and 13 deletions.
18 changes: 8 additions & 10 deletions substrafl/algorithms/pytorch/torch_base_algo.py
@@ -244,14 +244,14 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
             device = torch.device("cuda")
         return device

-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the checkpoint and update the internal state
         from it.
         Pop the values from the checkpoint so that we can ensure that it is empty at the
         end, i.e. all the values have been used.
         Args:
-            path (pathlib.Path): path where the checkpoint is saved
+            checkpoint (dict): the checkpoint to load the state from
         Returns:
             dict: checkpoint
@@ -260,13 +260,11 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         .. code-block:: python

-            def _update_from_checkpoint(self, path: Path) -> dict:
-                checkpoint = super()._update_from_checkpoint(path=path)
+            def _update_from_checkpoint(self, checkpoint: dict) -> None:
+                super()._update_from_checkpoint(checkpoint=checkpoint)
                 self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")
-                return checkpoint
+                return
         """
-        assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
-        checkpoint = torch.load(path, map_location=self._device)
         self._model.load_state_dict(checkpoint.pop("model_state_dict"))

         if self._optimizer is not None:
@@ -285,8 +283,6 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         else:
             torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))

-        return checkpoint
-
     def load_local_state(self, path: Path) -> "TorchAlgo":
         """Load the stateful arguments of this class.
         Child classes do not need to override that function.
@@ -297,7 +293,9 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
         Returns:
             TorchAlgo: The class with the loaded elements.
         """
-        checkpoint = self._update_from_checkpoint(path=path)
+        assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
+        checkpoint = torch.load(path, map_location=self._device)
+        self._update_from_checkpoint(checkpoint=checkpoint)
         assert len(checkpoint) == 0, f"Not all values from the checkpoint have been used: {checkpoint.keys()}"
         return self
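
Taken together, the hunks above split the old responsibilities of _update_from_checkpoint(path) in two: load_local_state now owns the file I/O, while _update_from_checkpoint only consumes an already-loaded dict. The following is a minimal, self-contained sketch of that pattern; MinimalTorchAlgo and its single model_state_dict entry are simplified stand-ins for illustration, not the real substrafl classes.

from pathlib import Path

import torch


class MinimalTorchAlgo:
    """Simplified stand-in illustrating the new checkpoint flow."""

    def __init__(self, model: torch.nn.Module):
        self._model = model
        self._device = torch.device("cpu")

    def _get_state_to_save(self) -> dict:
        # Everything that must survive a save/load round trip goes into one dict.
        return {"model_state_dict": self._model.state_dict()}

    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        # New signature: consume an already-loaded dict and pop every key used.
        self._model.load_state_dict(checkpoint.pop("model_state_dict"))

    def load_local_state(self, path: Path) -> "MinimalTorchAlgo":
        # File I/O now lives here rather than in _update_from_checkpoint.
        checkpoint = torch.load(path, map_location=self._device)
        self._update_from_checkpoint(checkpoint=checkpoint)
        assert len(checkpoint) == 0, f"Unused checkpoint keys: {checkpoint.keys()}"
        return self

Because each consumer pops the keys it uses, the emptiness assertion at the end of load_local_state flags any checkpoint entry that no class in the hierarchy claimed.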

6 changes: 3 additions & 3 deletions substrafl/algorithms/pytorch/torch_fed_pca_algo.py
@@ -336,7 +336,7 @@ def _get_state_to_save(self) -> dict:
         )
         return checkpoint

-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the checkpoint and update the internal state from it.
         Pop the values from the checkpoint so that we can ensure that it is empty at the
@@ -350,7 +350,7 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         Returns:
             dict: checkpoint
         """
-        checkpoint = super()._update_from_checkpoint(path)
+        super()._update_from_checkpoint(checkpoint)
         self.local_mean = checkpoint.pop("mean")
         self.local_covmat = checkpoint.pop("covariance_matrix")
-        return checkpoint
+        return
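
Under the new signature the override pattern is uniform: delegate to super() first so the parent consumes its keys, then pop the strategy-specific ones. Continuing the MinimalTorchAlgo sketch above, a hypothetical subclass (its local_mean/local_covmat attributes echo the FedPCA change, but the class itself is illustrative, not the real TorchFedPCAAlgo):

class MinimalFedPCAAlgo(MinimalTorchAlgo):
    """Hypothetical subclass restoring strategy-specific state."""

    def __init__(self, model: torch.nn.Module):
        super().__init__(model)
        self.local_mean = torch.zeros(2)
        self.local_covmat = torch.zeros(2, 2)

    def _get_state_to_save(self) -> dict:
        checkpoint = super()._get_state_to_save()
        checkpoint["mean"] = self.local_mean
        checkpoint["covariance_matrix"] = self.local_covmat
        return checkpoint

    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        # Parent pops its keys first; then pop the strategy-specific ones.
        super()._update_from_checkpoint(checkpoint)
        self.local_mean = checkpoint.pop("mean")
        self.local_covmat = checkpoint.pop("covariance_matrix")


# Round trip: save to disk, then reload into a fresh instance.
algo = MinimalFedPCAAlgo(torch.nn.Linear(2, 2))
torch.save(algo._get_state_to_save(), "state.pt")
MinimalFedPCAAlgo(torch.nn.Linear(2, 2)).load_local_state(Path("state.pt"))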
