From 78224dc23ab5dab4c720a7c4b50fada623d7cce6 Mon Sep 17 00:00:00 2001
From: jeandut
Date: Thu, 11 Jan 2024 15:50:42 +0100
Subject: [PATCH] refactor: modifying _update_from_checkpoint signature

Signed-off-by: jeandut
---
 .../algorithms/pytorch/torch_base_algo.py     | 18 ++++++++----------
 .../algorithms/pytorch/torch_fed_pca_algo.py  |  8 ++++----
 .../algorithms/pytorch/torch_scaffold_algo.py |  9 ++++-----
 3 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py
index 51060753..236c38c6 100644
--- a/substrafl/algorithms/pytorch/torch_base_algo.py
+++ b/substrafl/algorithms/pytorch/torch_base_algo.py
@@ -244,14 +244,14 @@ def _get_torch_device(self, use_gpu: bool) -> torch.device:
             device = torch.device("cuda")
         return device
 
-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the checkpoint and update the internal state
         from it.
         Pop the values from the checkpoint so that we can ensure that it is empty at the
         end, i.e. all the values have been used.
 
         Args:
-            path (pathlib.Path): path where the checkpoint is saved
+            checkpoint (dict): the checkpoint to load.
 
         Returns:
             dict: checkpoint
@@ -260,13 +260,11 @@ def _update_from_checkpoint(self, path: Path) -> dict:
 
             .. code-block:: python
 
-                def _update_from_checkpoint(self, path: Path) -> dict:
-                    checkpoint = super()._update_from_checkpoint(path=path)
+                def _update_from_checkpoint(self, checkpoint: dict) -> None:
+                    super()._update_from_checkpoint(checkpoint=checkpoint)
                     self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")
-                    return checkpoint
+                    return
         """
-        assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
-        checkpoint = torch.load(path, map_location=self._device)
         self._model.load_state_dict(checkpoint.pop("model_state_dict"))
 
         if self._optimizer is not None:
@@ -285,8 +283,6 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         else:
             torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))
 
-        return checkpoint
-
     def load_local_state(self, path: Path) -> "TorchAlgo":
         """Load the stateful arguments of this class. Child classes do not need to
         override that function.
@@ -297,7 +293,9 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
         Returns:
             TorchAlgo: The class with the loaded elements.
         """
-        checkpoint = self._update_from_checkpoint(path=path)
+        assert path.is_file(), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
+        checkpoint = torch.load(path, map_location=self._device)
+        self._update_from_checkpoint(checkpoint=checkpoint)
         assert len(checkpoint) == 0, f"Not all values from the checkpoint have been used: {checkpoint.keys()}"
         return self
 
diff --git a/substrafl/algorithms/pytorch/torch_fed_pca_algo.py b/substrafl/algorithms/pytorch/torch_fed_pca_algo.py
index 12526754..6fbe2c4f 100644
--- a/substrafl/algorithms/pytorch/torch_fed_pca_algo.py
+++ b/substrafl/algorithms/pytorch/torch_fed_pca_algo.py
@@ -336,7 +336,7 @@ def _get_state_to_save(self) -> dict:
         )
         return checkpoint
 
-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the checkpoint and update the internal state
         from it.
         Pop the values from the checkpoint so that we can ensure that it is empty at the
@@ -343,13 +343,13 @@ def _update_from_checkpoint(self, path: Path) -> dict:
         end, i.e. all the values have been used. The local mean and the covariance matrix need
         to be popped.
 
         Args:
-            path (pathlib.Path): path where the checkpoint is saved
+            checkpoint (dict): the checkpoint to load.
 
         Returns:
             dict: checkpoint
         """
-        checkpoint = super()._update_from_checkpoint(path)
+        super()._update_from_checkpoint(checkpoint=checkpoint)
         self.local_mean = checkpoint.pop("mean")
         self.local_covmat = checkpoint.pop("covariance_matrix")
-        return checkpoint
+        return
diff --git a/substrafl/algorithms/pytorch/torch_scaffold_algo.py b/substrafl/algorithms/pytorch/torch_scaffold_algo.py
index a7f589f0..9c3695f2 100644
--- a/substrafl/algorithms/pytorch/torch_scaffold_algo.py
+++ b/substrafl/algorithms/pytorch/torch_scaffold_algo.py
@@ -1,6 +1,5 @@
 import logging
 from enum import IntEnum
-from pathlib import Path
 from typing import Any
 from typing import List
 from typing import Optional
@@ -493,18 +492,18 @@ def _get_state_to_save(self) -> dict:
         )
         return local_state
 
-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the local state from the checkpoint.
 
         Args:
-            path (pathlib.Path): path where the checkpoint is saved
+            checkpoint (dict): the checkpoint to load.
 
         Returns:
             dict: checkpoint
         """
-        checkpoint = super()._update_from_checkpoint(path=path)
+        super()._update_from_checkpoint(checkpoint=checkpoint)
         self._client_control_variate = checkpoint.pop("client_control_variate")
-        return checkpoint
+        return
 
     def summary(self):
         """Summary of the class to be exposed in the experiment summary file
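
Note: the sketch below is an editorial illustration, not part of the patch. It shows the override contract this refactor introduces from a subclass's point of view: torch.load now happens once in load_local_state, which hands the resulting dict to _update_from_checkpoint; each override pops only the keys it owns and returns None, so the final "assert len(checkpoint) == 0" can verify that every saved value was consumed. MyStrategyAlgo and _strategy_specific_variable are hypothetical names borrowed from the docstring example above, and the abstract methods a real strategy algo must also implement (train, etc.) are omitted.

.. code-block:: python

    from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo


    class MyStrategyAlgo(TorchAlgo):
        """Hypothetical strategy algo keeping one extra piece of local state."""

        def _get_state_to_save(self) -> dict:
            # Extend the parent checkpoint dict with strategy-specific state.
            checkpoint = super()._get_state_to_save()
            checkpoint["strategy_specific_variable"] = self._strategy_specific_variable
            return checkpoint

        def _update_from_checkpoint(self, checkpoint: dict) -> None:
            # New contract: receive the already-loaded dict, pop the keys this
            # class owns, and return None; file I/O stays in load_local_state.
            super()._update_from_checkpoint(checkpoint=checkpoint)
            self._strategy_specific_variable = checkpoint.pop("strategy_specific_variable")

Keeping the file handling in load_local_state means every override is pure dict manipulation, which is what makes the empty-dict assertion a reliable completeness check.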