Skip to content

Commit

Permalink
refactor: scaffold compatibility and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 19, 2024
1 parent a56ce18 commit 329c831
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions substrafl/algorithms/pytorch/torch_fed_pca_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,12 @@ def _update_from_checkpoint(self, checkpoint: dict) -> None:
to be popped.
Args:
path (pathlib.Path): path where the checkpoint is saved
checkpoint (dict): the checkpoint to load.
Returns:
dict: checkpoint
"""
super()._update_from_checkpoint(checkpoint)
super()._update_from_checkpoint(checkpoint=checkpoint)
self.local_mean = checkpoint.pop("mean")
self.local_covmat = checkpoint.pop("covariance_matrix")
return
9 changes: 4 additions & 5 deletions substrafl/algorithms/pytorch/torch_scaffold_algo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from enum import IntEnum
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
Expand Down Expand Up @@ -493,18 +492,18 @@ def _get_state_to_save(self) -> dict:
)
return local_state

def _update_from_checkpoint(self, path: Path) -> dict:
def _update_from_checkpoint(self, checkpoint: dict) -> None:
"""Load the local state from the checkpoint.
Args:
path (pathlib.Path): path where the checkpoint is saved
checkpoint (dict): the checkpoint to load.
Returns:
dict: checkpoint
"""
checkpoint = super()._update_from_checkpoint(path=path)
checkpoint = super()._update_from_checkpoint(checkpoint=checkpoint)
self._client_control_variate = checkpoint.pop("client_control_variate")
return checkpoint
return

def summary(self):
"""Summary of the class to be exposed in the experiment summary file
Expand Down

0 comments on commit 329c831

Please sign in to comment.