Skip to content

Commit

Permalink
Konsti fix training for flaxmodels (#97)
Browse files Browse the repository at this point in the history
* Fix the training strategy of the loss aware reservoir for flax models

Issue:
When setting the number of latest points in the lar strategy to the same size as the overall training data, a forward pass of an empty data set was computed.
This does not throw an error in stax but in flax.

- Remove the need for an empty forward pass
- extend the tests to check for the fixed problem

* apply black and isort

* merge main

---------

Co-authored-by: knikolaou <>
  • Loading branch information
KonstiNik authored Jul 25, 2023
1 parent 807d4f6 commit 244d707
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 44 deletions.
157 changes: 118 additions & 39 deletions CI/unit_tests/training_strategies/test_loss_aware_reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@

import numpy as np
import optax
from flax import linen as nn
from jax import random
from neural_tangents import stax
from numpy.testing import assert_array_equal

from znnl.accuracy_functions import AccuracyFunction
from znnl.distance_metrics import DistanceMetric
from znnl.loss_functions import MeanPowerLoss
from znnl.models import JaxModel, NTModel
from znnl.models import FlaxModel, JaxModel, NTModel
from znnl.training_recording import JaxRecorder
from znnl.training_strategies import LossAwareReservoir, RecursiveMode
from znnl.training_strategies.training_decorator import train_func
Expand Down Expand Up @@ -97,11 +98,46 @@ def train_model(
return epochs, batch_size


class FlaxArchitecture(nn.Module):
    """
    Minimal Flax network used as the test model.

    The network is a dense layer with 5 features, followed by a ReLU and a
    dense output layer with a single feature. Both dense layers use a bias.
    """

    @nn.compact
    def __call__(self, x):
        # Hidden layer and activation.
        hidden = nn.relu(nn.Dense(features=5, use_bias=True)(x))
        # Single-feature output layer.
        return nn.Dense(features=1, use_bias=True)(hidden)


class TestLossAwareReservoir:
"""
Unit test suite of the loss aware reservoir training strategy.
"""

@classmethod
def setup_class(cls):
    """
    Create models and data for the tests.

    The following attributes are stored on the class:

    * train_ds / test_ds : dict with "inputs" of shape (10, 8) and "targets"
      of shape (10, 1), drawn from a normal distribution. The test set uses
      the same arrays as the train set.
    * nt_model : NTModel built from a Dense(5) -> Relu -> Dense(1) stax module.
    * flax_model : FlaxModel built from the FlaxArchitecture module.
    """
    # Draw inputs and targets from independent PRNG keys; a JAX key must not
    # be reused for more than one random call.
    input_key, target_key = random.split(random.PRNGKey(1), 2)
    x = random.normal(input_key, (10, 8))
    y = random.normal(target_key, (10, 1))
    cls.train_ds = {"inputs": x, "targets": y}
    cls.test_ds = {"inputs": x, "targets": y}

    cls.nt_model = NTModel(
        nt_module=stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1)),
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(1, 8),
    )
    cls.flax_model = FlaxModel(
        flax_module=FlaxArchitecture(),
        optimizer=optax.adam(learning_rate=0.001),
        input_shape=(1, 8),
    )

def test_reservoir_sorting(self):
"""
Test the sorting of the reservoir.
Expand Down Expand Up @@ -139,77 +175,120 @@ def test_reservoir_sorting(self):
selection_idx = np.argsort(np.abs(raw_x))[::-1][:4]
assert_array_equal(reservoir, selection_idx)

@classmethod
def setup_class(cls):
    """
    Create data for the tests.

    Stores train_ds and test_ds on the class. Both are dicts with "inputs" of
    shape (10, 8) and "targets" of shape (10, 1), drawn from a normal
    distribution. The test set uses the same arrays as the train set.
    """
    # Draw inputs and targets from independent PRNG keys; a JAX key must not
    # be reused for more than one random call.
    input_key, target_key = random.split(random.PRNGKey(1), 2)
    x = random.normal(input_key, (10, 8))
    y = random.normal(target_key, (10, 1))
    cls.train_ds = {"inputs": x, "targets": y}
    cls.test_ds = {"inputs": x, "targets": y}

def test_latest_point_exclusion(self):
def test_update_reservoir(self):
"""
Test the method _update_reservoir excludes the latest points from train_ds.
When selecting latest_points > 0, this number of points is separated from the
train data. The selected points will be appended to every batch.
This test checks if the method _update_reservoir removes the latest_points from
the data, as they cannot be part of the reservoir. The reservoir must only
consist of already seen data.
Test the method _update_reservoir.
Test whether the method excludes the latest points from train_ds.
When selecting latest_points > 0, this number of points is separated from
the train data. The selected points will be appended to every batch.
This test checks if the method _update_reservoir removes the latest_points
from the data, as they cannot be part of the reservoir. The reservoir must
only consist of already seen data.
1. For reservoir_size = len(train_ds)
* Shrinking reservoir for latest_points = 1
* Shrinking reservoir for latest_points = 4
* Shrink the reservoir to include no points for latest_points = 10
2. For reservoir_size = 5 and len(train_ds) = 10
* No shrinking reservoir size for latest_points = 4
Perform both tests for nt and flax models.
"""

model = NTModel(
nt_module=stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1)),
optimizer=optax.adam(learning_rate=0.001),
input_shape=(1, 8),
)
nt_model = self.nt_model
flax_model = self.flax_model

# Test for latest_points = 1
trainer = LossAwareReservoir(
model=model,
nt_trainer = LossAwareReservoir(
model=nt_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=1,
)
flax_trainer = LossAwareReservoir(
model=flax_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=1,
)

trainer.train_data_size = len(self.train_ds["inputs"])
reservoir = trainer._update_reservoir(train_ds=self.train_ds)
assert len(self.train_ds["inputs"]) - 1 == len(reservoir)
nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
assert self.train_ds["inputs"].shape[0] - 1 == len(reservoir)
reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
assert self.train_ds["inputs"].shape[0] - 1 == len(reservoir)

# Test for latest_points = 4
trainer = LossAwareReservoir(
model=model,
nt_trainer = LossAwareReservoir(
model=nt_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=4,
)
flax_trainer = LossAwareReservoir(
model=flax_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=4,
)

trainer.train_data_size = len(self.train_ds["inputs"])
reservoir = trainer._update_reservoir(train_ds=self.train_ds)
assert len(self.train_ds["inputs"]) - 4 == len(reservoir)
nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
assert self.train_ds["inputs"].shape[0] - 4 == len(reservoir)
reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
assert self.train_ds["inputs"].shape[0] - 4 == len(reservoir)

# Test for latest_points = 10
nt_trainer = LossAwareReservoir(
model=nt_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=10,
)
flax_trainer = LossAwareReservoir(
model=flax_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=10,
latest_points=10,
)

nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
assert 0 == len(reservoir)
reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
assert 0 == len(reservoir)

# Test for latest_points = 4 but for reservoir_size = 5. The reservoir size
# should not be affected now.
trainer = LossAwareReservoir(
model=model,
nt_trainer = LossAwareReservoir(
model=nt_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=5,
latest_points=4,
)
flax_trainer = LossAwareReservoir(
model=flax_model,
loss_fn=MeanPowerLoss(order=2),
disable_loading_bar=True,
reservoir_size=5,
latest_points=4,
)

trainer.train_data_size = len(self.train_ds["inputs"])
reservoir = trainer._update_reservoir(train_ds=self.train_ds)
nt_trainer.train_data_size = self.train_ds["inputs"].shape[0]
flax_trainer.train_data_size = self.train_ds["inputs"].shape[0]
reservoir = nt_trainer._update_reservoir(train_ds=self.train_ds)
assert 5 == len(reservoir)
reservoir = flax_trainer._update_reservoir(train_ds=self.train_ds)
assert 5 == len(reservoir)

def test_update_training_kwargs(self):
Expand Down
16 changes: 11 additions & 5 deletions znnl/training_strategies/loss_aware_reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def _update_reservoir(self, train_ds: dict) -> List[int]:
Updates the reservoir in the following steps:
* Exclude latest_points from the train_data
* Check whether the reservoir will be empty or it can cover all data
* Compute distance of representations of the remaining training set
* Sort the training set according to the distance
Expand All @@ -179,13 +180,18 @@ def _update_reservoir(self, train_ds: dict) -> List[int]:
else:
old_data = {k: v[: -self.latest_points, ...] for k, v in train_ds.items()}

distances = self._compute_distance(old_data)

# If no old data is available, return an empty array
if old_data["inputs"].shape[0] == 0:
data_idx = np.array([])
# Return the old train data indices if the reservoir can cover them all
if self.reservoir_size >= self.train_data_size - self.latest_points:
return np.arange(self.train_data_size - self.latest_points)
elif self.reservoir_size >= self.train_data_size - self.latest_points:
data_idx = np.arange(self.train_data_size - self.latest_points)
# If the reservoir is smaller than the train data, select data via the loss
return np.argsort(distances)[::-1][: self.reservoir_size]
else:
distances = self._compute_distance(old_data)
data_idx = np.argsort(distances)[::-1][: self.reservoir_size]

return data_idx

def _append_latest_points(self, data_idx: List[int], freq: int = 1):
"""
Expand Down

0 comments on commit 244d707

Please sign in to comment.