Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow immutable file upload for wandb logger #20193

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def any_lightning_module_function_or_hook(self):
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
checkpoint_name: Name of the model checkpoint artifact being logged.
add_file_policy: If "mutable", copies file to tempdirectory before upload.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.

Raises:
Expand All @@ -303,6 +304,7 @@ def __init__(
experiment: Union["Run", "RunDisabled", None] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
add_file_policy: Literal["mutable", "immutable"] = "mutable",
**kwargs: Any,
) -> None:
if not _WANDB_AVAILABLE:
Expand All @@ -322,6 +324,7 @@ def __init__(
self._experiment = experiment
self._logged_model_time: Dict[str, float] = {}
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self.add_file_policy = add_file_policy

# paths are processed as strings
if save_dir is not None:
Expand Down Expand Up @@ -671,7 +674,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
if not self._checkpoint_name:
self._checkpoint_name = f"model-{self.experiment.id}"
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
Expand Down
Loading