Skip to content

Commit

Permalink
feat!: all metric in one task (#117)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored May 11, 2023
1 parent 98ccdff commit 8af9734
Show file tree
Hide file tree
Showing 19 changed files with 437 additions and 269 deletions.
55 changes: 55 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,61 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- BREAKING: Metrics are now given as `metric_functions` and not as `metric_keys`. The functions given as `metric_functions` to test data nodes are automatically registered in a new Substra function by SubstraFL ([#117](https://github.com/Substra/substrafl/pull/117)).
The new `metric_functions` argument of the `TestDataNode` class replaces the `metric_keys` one. It accepts a dictionary (each key is used as the identifier of the function given as its value), a list of functions, or directly a function if there is only one metric to compute (`function.__name__` is then used as the identifier).
Installed dependencies are the `dependencies` passed to `execute_experiment`, and permissions are the same as for the predict function.

From a user point of view, the metric registration changes from:

```py
def accuracy(datasamples, predictions_path):
y_true = datasamples["labels"]
y_pred = np.load(predictions_path)

return accuracy_score(y_true, np.argmax(y_pred, axis=1))

metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

permissions_metric = Permissions(public=False, authorized_ids=DATA_PROVIDER_ORGS_ID)

metric_key = add_metric(
client=client,
metric_function=accuracy,
permissions=permissions_metric,
dependencies=metric_deps,
)

test_data_nodes = [
TestDataNode(
organization_id=org_id,
data_manager_key=dataset_keys[org_id],
test_data_sample_keys=[test_datasample_keys[org_id]],
metric_keys=[metric_key],
)
for org_id in DATA_PROVIDER_ORGS_ID
]
```

to:

```py
def accuracy(datasamples, predictions_path):
y_true = datasamples["labels"]
y_pred = np.load(predictions_path)

return accuracy_score(y_true, np.argmax(y_pred, axis=1))

test_data_nodes = [
TestDataNode(
organization_id=org_id,
data_manager_key=dataset_keys[org_id],
test_data_sample_keys=[test_datasample_keys[org_id]],
metric_functions={"Accuracy": accuracy},
)
for org_id in DATA_PROVIDER_ORGS_ID
]
```

- Enforce kwargs for user-facing functions with more than 3 parameters ([#109](https://github.com/Substra/substrafl/pull/109))
- Remove references to `composite`. Replace by `train_task`. ([#108](https://github.com/Substra/substrafl/pull/108))

Expand Down
11 changes: 0 additions & 11 deletions benchmark/camelyon/pure_substrafl/assets/Dockerfile

This file was deleted.

23 changes: 0 additions & 23 deletions benchmark/camelyon/pure_substrafl/assets/metric.py

This file was deleted.

81 changes: 18 additions & 63 deletions benchmark/camelyon/pure_substrafl/register_assets.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
import json
import os
import tarfile
from copy import deepcopy
from pathlib import Path
from typing import List
from typing import Optional

import numpy as np
import substra
import yaml
from substra.sdk.schemas import AssetKind
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from substra.sdk.schemas import DataSampleSpec
from substra.sdk.schemas import DatasetSpec
from substra.sdk.schemas import FunctionInputSpec
from substra.sdk.schemas import FunctionOutputSpec
from substra.sdk.schemas import FunctionSpec
from substra.sdk.schemas import Permissions
from tqdm import tqdm

from substrafl.nodes import AggregationNode
from substrafl.nodes import TestDataNode
from substrafl.nodes import TrainDataNode
from substrafl.nodes.node import InputIdentifiers
from substrafl.nodes.node import OutputIdentifiers

CURRENT_DIRECTORY = Path(__file__).parent

Expand Down Expand Up @@ -123,7 +119,6 @@ def add_duplicated_dataset(
"train_data_sample_keys": ["766d2029-f90b-440e-8b39-2389ab04041d"]
},
...
"metric_key": "e5a99be6-0138-461a-92fe-23f685cdc9e1"
}
msp_id (str): asset_keys key where to find the registered assets for the given client
kind (str, optional): Kind of data sample to add, either train or test. Defaults to "train".
Expand Down Expand Up @@ -210,53 +205,6 @@ def get_train_data_nodes(
return train_data_nodes


def register_metric(client: substra.Client) -> str:
    """Package the default metric and register it as a Substra function.

    The ``Dockerfile`` and ``metric.py`` found in ``ASSETS_DIRECTORY`` are bundled into
    ``metric.tar.gz`` (``metric.py`` is stored in the archive as ``metrics.py``), and the
    archive is submitted with ``PUBLIC_PERMISSIONS`` under the name ``"ROC"``.

    Args:
        client (substra.Client): Substra client used to register the metric.

    Returns:
        str: Key returned by Substra for the registered metric.
    """
    archive_path = ASSETS_DIRECTORY / "metric.tar.gz"

    # (file name in the assets directory, name inside the archive)
    archive_members = (
        ("Dockerfile", "Dockerfile"),
        ("metric.py", "metrics.py"),
    )
    with tarfile.open(archive_path, "w:gz") as bundle:
        for source_name, archive_name in archive_members:
            bundle.add(ASSETS_DIRECTORY / source_name, arcname=archive_name)

    # The metric consumes the data samples, the opener and the predictions ...
    function_inputs = [
        FunctionInputSpec(
            identifier=InputIdentifiers.datasamples,
            kind=AssetKind.data_sample.value,
            optional=False,
            multiple=True,
        ),
        FunctionInputSpec(
            identifier=InputIdentifiers.opener,
            kind=AssetKind.data_manager.value,
            optional=False,
            multiple=False,
        ),
        FunctionInputSpec(
            identifier=InputIdentifiers.predictions,
            kind=AssetKind.model.value,
            optional=False,
            multiple=False,
        ),
    ]
    # ... and produces a single performance output.
    function_outputs = [
        FunctionOutputSpec(
            identifier=OutputIdentifiers.performance,
            kind=AssetKind.performance.value,
            multiple=False,
        ),
    ]

    spec = FunctionSpec(
        inputs=function_inputs,
        outputs=function_outputs,
        name="ROC",
        description=ASSETS_DIRECTORY / "description.md",
        file=archive_path,
        permissions=PUBLIC_PERMISSIONS,
    )
    return client.add_function(spec)


def get_test_data_nodes(
clients: List[substra.Client], test_folder: Path, asset_keys: dict, nb_data_sample
) -> TestDataNode:
Expand All @@ -274,13 +222,20 @@ def get_test_data_nodes(
Returns:
TestDataNode: Substrafl test data.
"""
# only one metric is needed as permissions are public
metric_key = asset_keys.get("metric_key") or register_metric(clients[0])
asset_keys.update(
{
"metric_key": metric_key,
}
)

def auc(datasamples, predictions_path):
    """ROC AUC of the predictions stored at ``predictions_path``.

    Predictions are read with ``np.load`` and compared against ``datasamples.y_true``.
    Returns 0 when ``y_true`` holds fewer than two distinct values.
    """
    predictions = np.load(predictions_path)
    labels = datasamples.y_true

    # Guard: with a single class (or no label at all) the score is reported as 0.
    if len(set(labels)) <= 1:
        return 0
    return roc_auc_score(labels, predictions)

def accuracy(datasamples, predictions_path):
    """Accuracy of the rounded predictions stored at ``predictions_path``.

    Predictions are read with ``np.load``, rounded with ``np.round`` and compared against
    ``datasamples.y_true``. Returns 0 when ``y_true`` holds fewer than two distinct values.
    """
    predictions = np.load(predictions_path)
    labels = datasamples.y_true

    # Guard: with a single class (or no label at all) the score is reported as 0.
    if len(set(labels)) <= 1:
        return 0
    return accuracy_score(labels, np.round(predictions))

test_data_nodes = []

Expand All @@ -300,7 +255,7 @@ def get_test_data_nodes(
organization_id=msp_id,
data_manager_key=asset_keys.get(msp_id)["dataset_key"],
test_data_sample_keys=asset_keys.get(msp_id)["test_data_sample_keys"],
metric_keys=[metric_key],
metric_functions={"ROC AUC": auc, "Accuracy": accuracy},
)
)

Expand Down
12 changes: 7 additions & 5 deletions benchmark/camelyon/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pure_substrafl.register_assets import load_asset_keys
from pure_substrafl.register_assets import save_asset_keys
from pure_torch.strategies import basic_fed_avg
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from substra.sdk.models import ComputePlanStatus
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -85,9 +86,9 @@ def substrafl_fed_avg(
seed=seed, learning_rate=learning_rate, num_workers=num_workers, index_generator=index_generator, model=model
)

# Algo dependencies
# Dependencies
base = Path(__file__).parent
algo_deps = Dependency(
dependencies = Dependency(
pypi_dependencies=["torch", "numpy", "sklearn"],
local_code=[base / "common", base / "weldon_fedavg.py"],
editable_mode=False,
Expand All @@ -107,7 +108,7 @@ def substrafl_fed_avg(
evaluation_strategy=evaluation,
aggregation_node=aggregation_node,
num_rounds=n_rounds,
dependencies=algo_deps,
dependencies=dependencies,
experiment_folder=Path(__file__).resolve().parent / "benchmark_cl_experiment_folder",
)

Expand Down Expand Up @@ -242,7 +243,8 @@ def torch_fed_avg(

# Fusion, sigmoid and to numpy
y_pred = torch.sigmoid(torch.cat(y_pred)).numpy()
metric = roc_auc_score(y_true, y_pred) if len(set(y_true)) > 1 else 0
metrics.update({k: metric})
auc = roc_auc_score(y_true, y_pred) if len(set(y_true)) > 1 else 0
acc = accuracy_score(y_true, np.round(y_pred)) if len(set(y_true)) > 1 else 0
metrics.update({k: {"ROC AUC": auc, "Accuracy": acc}})

return metrics
8 changes: 8 additions & 0 deletions substrafl/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class EmptySharedStatesError(Exception):
StrategySharedState object."""


class ExistingRegisteredMetricError(Exception):
    """Raised when a metric with the same name is already registered."""


class IncompatibleAlgoStrategyError(Exception):
    """Raised when an algo is not compatible with the strategy it is used with."""

Expand All @@ -43,6 +47,10 @@ class InvalidPathError(Exception):
"""Invalid path."""


class InvalidMetricIdentifierError(Exception):
    """A metric name or identifier cannot be a SubstraFL ``OutputIdentifiers`` value."""


class KeyMetadataError(Exception):
"""``substrafl_version``, ``substra_version`` and ``substratools_version`` keys can't be added
to the experiment metadata."""
Expand Down
35 changes: 21 additions & 14 deletions substrafl/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _register_operations(
evaluation_strategy: Optional[EvaluationStrategy],
dependencies: Dependency,
) -> Tuple[List[dict], Dict[RemoteStruct, OperationKey]]:
"""Register the operations in Substra: define the algorithms we need and submit them
"""Register the operations in Substra: define the functions we need and submit them
Args:
client (substra.Client): substra client
Expand All @@ -44,16 +44,17 @@ def _register_operations(
centralized strategies
evaluation_strategy (typing.Optional[EvaluationStrategy]): the evaluation strategy
if there is one dependencies
(Dependency): dependencies of the train algo
(Dependency): dependencies of the experiment
Returns:
typing.Tuple[typing.List[dict], typing.Dict[RemoteStruct, OperationKey]]:
tasks, operation_cache
"""
# `register_operations` methods from the different organizations store the id of the already registered
# algorithm so we don't add them twice
# functions so we don't add them twice
operation_cache = dict()
predict_algo_cache = dict()
predict_function_cache = dict()
test_function_cache = dict()
tasks = list()

train_data_organizations_id = {train_data_node.organization_id for train_data_node in train_data_nodes}
Expand All @@ -80,10 +81,16 @@ def _register_operations(

if evaluation_strategy is not None:
for test_data_node in evaluation_strategy.test_data_nodes:
predict_algo_cache = test_data_node.register_predict_operations(
predict_function_cache = test_data_node.register_predict_operations(
client=client,
permissions=permissions,
cache=predict_algo_cache,
cache=predict_function_cache,
dependencies=dependencies,
)
test_function_cache = test_data_node.register_test_operations(
client=client,
permissions=permissions,
cache=test_function_cache,
dependencies=dependencies,
)

Expand Down Expand Up @@ -223,10 +230,10 @@ def execute_experiment(
task_submission_batch_size: int = 500,
) -> substra.sdk.models.ComputePlan:
"""Run a complete experiment. This will train (on the `train_data_nodes`) and test (on the
`test_data_nodes`) your `algo` with the specified `strategy` `n_rounds` times and return the
`test_data_nodes`) the specified `strategy` `n_rounds` times and return the
compute plan object from the Substra platform.
In substrafl, operations are linked to each other statically before being submitted to substra.
In SubstraFL, operations are linked to each other statically before being submitted to Substra.
The execution of:
Expand All @@ -235,7 +242,7 @@ def execute_experiment(
generate the static graph of operations.
Each element necessary for those operations (Tasks and Algorithms)
Each element necessary for those operations (Tasks and Functions)
is registered to the Substra platform thanks to the specified client.
Finally, the compute plan is sent and executed.
Expand All @@ -244,15 +251,15 @@ def execute_experiment(
Args:
client (substra.Client): A substra client to interact with the Substra platform
strategy (Strategy): The strategy by which your algorithm will be executed
strategy (Strategy): The strategy that will be executed
train_data_nodes (typing.List[TrainDataNode]): List of the nodes where training on data
occurs evaluation_strategy (EvaluationStrategy, Optional): If None performance will not be measured at all.
Otherwise measuring of performance will follow the EvaluationStrategy. Defaults to None.
aggregation_node (typing.Optional[AggregationNode]): For centralized strategy, the aggregation
node, where all the shared tasks occurs else None.
num_rounds (int): The number of time your strategy will be executed
dependencies (Dependency, Optional): Dependencies of the algorithm. It must be defined from
the substrafl Dependency class. Defaults None.
dependencies (Dependency, Optional): Dependencies of the experiment. It must be defined from
the SubstraFL Dependency class. Defaults None.
experiment_folder (typing.Union[str, pathlib.Path]): path to the folder where the experiment summary is saved.
clean_models (bool): Clean the intermediary models on the Substra platform. Set it to False
if you want to download or re-use intermediary models. This causes the disk space to fill
Expand All @@ -278,7 +285,7 @@ def execute_experiment(
train_organization_ids = [train_data_node.organization_id for train_data_node in train_data_nodes]

if len(train_organization_ids) != len(set(train_organization_ids)):
raise ValueError("Training multiple algorithms on the same organization is not supported right now.")
raise ValueError("Training multiple functions on the same organization is not supported right now.")

if evaluation_strategy is not None:
_check_evaluation_strategy(evaluation_strategy, num_rounds)
Expand Down Expand Up @@ -307,7 +314,7 @@ def execute_experiment(
)

# Computation graph is created
logger.info("Registering the algorithm to Substra.")
logger.info("Registering the functions to Substra.")
tasks, operation_cache = _register_operations(
client=client,
train_data_nodes=train_data_nodes,
Expand Down
Loading

0 comments on commit 8af9734

Please sign in to comment.