Konsti papyrus recording (#121)
* Write independent module for computing the NTK

* Remove NTK computation from the model
Adapt all tests

* Run black and isort

* remove ntk from huggingface flax model

* Adapt integration tests to new ntk computation

* Adapt notebook API to new ntk computation

* fix ntk computation hfmodels

* remove unnecessary imports

* Install papyrus via requirements.txt

* Create Papyrus Jax Recorder and tests (a usage sketch of the new recorder API follows the changed-files summary below)

* make NTK computation return ntk in list

* Create tests for the jax ntk computation

* remove unused imports from model tests

* adapt simple training to new recorder

* adapt jax recorder to api changes of papyrus

* Adapt Examples to new api
- Computing CVs Example
- Contrastive Loss Example
- Using data recorder Example

* run black

* Adapt training strategies to new recorder
Adapt examples:
- Using Training Strategies
- ResNet Example

* Adapt tests to new recorders

* Adapt the trace opt to the more flexible ntk calculation.
This allows for subsampling the ntk.

* Add explanation to the data recorders notebook

* include access token in papyrus installation

* change access token reference

* try granting access to token via environment

* exclude token from install command

* try setting git config including access token

* include access token also in doc building

* Write NTK subsampling class.

* Make the ntk calculation use the entire data set with inputs and targets.
Add option to set the data set keys in the init of the ntk computation

* run black

* Write class-wise ntk computation

* make loss derivative computation work and include in example

* Write ntk combinations which can be used to compute the mutual information of a system.

* Move ntk computation to analysis

* Fix bugs in jax combinations

* Complete Docstrings of papyrus jax recorder

* Write example to compute the neural mutual information

* remove outdated recorder

* Adapt integration test to new recorder

* Include changes from papyrus going public.

---------

Co-authored-by: knikolaou <>
KonstiNik authored Jun 10, 2024
1 parent ac59d15 commit 30c1e70
Showing 42 changed files with 3,378 additions and 1,458 deletions.
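
The central API change exercised by the diff below: the NTK is no longer computed inside the model but by a standalone JAXNTKComputation object from znnl.analysis, and the JaxRecorder is now configured with papyrus measurement objects instead of boolean flags. The following is a minimal usage sketch assembled only from the calls visible in the test diff below; the data set size, learning rate, update rate, and the exact measurement list are illustrative, and the training call itself is omitted because its full signature is not shown in this commit view.

import optax
from neural_tangents import stax
from papyrus.measurements import NTK, Accuracy, Loss

import znnl as nl
from znnl.analysis import JAXNTKComputation

# Data and a small dense network (sizes are illustrative).
data_generator = nl.data.MNISTGenerator(ds_size=10)
network = stax.serial(
    stax.Flatten(), stax.Dense(128), stax.Relu(), stax.Dense(10)
)
model = nl.models.NTModel(
    nt_module=network,
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 28, 28, 1),
)

# Recorder: measurements are passed explicitly as papyrus objects.
recorder = nl.training_recording.JaxRecorder(
    storage_path=".",
    name="trainer",
    update_rate=1,
    measurements=[
        Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
        Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
        NTK(),
    ],
)

# The NTK computation is attached at instantiation time rather than
# being owned by the model.
recorder.instantiate_recorder(
    data_set=data_generator.train_ds,
    model=model,
    ntk_computation=JAXNTKComputation(model.ntk_apply_fn),
)

# Training strategies simply receive the new recorder.
training_strategy = nl.training_strategies.SimpleTraining(
    model=model,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
    recorders=[recorder],
)
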
3 changes: 3 additions & 0 deletions .github/workflows/doc.yml
@@ -13,8 +13,11 @@ jobs:
with:
python-version: '3.11'
- name: Install dependencies
env:
super_secret: ${{ secrets.PAPYRUS_ACCESS }}
run: |
sudo apt install pandoc
git config --global url."https://${{ secrets.PAPYRUS_ACCESS }}@github".insteadOf https://github
pip install -r dev-requirements.txt
pip install -r requirements.txt
pip install .
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -21,6 +21,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
git config --global url."https://${{ secrets.PAPYRUS_ACCESS }}@github".insteadOf https://github
python -m pip install --upgrade pip
pip install -r dev-requirements.txt
pip install -r requirements.txt
@@ -95,12 +95,10 @@ def setup_class(cls):
cls.resnet18 = HuggingFaceFlaxModel(
resnet18,
optax.sgd(learning_rate=1e-3),
batch_size=3,
)
cls.resnet50 = HuggingFaceFlaxModel(
resnet50,
optax.sgd(learning_rate=1e-3),
batch_size=3,
)

key = random.PRNGKey(0)
@@ -37,8 +37,10 @@
import optax
from neural_tangents import stax
from numpy import testing
from papyrus.measurements import NTK, Accuracy, Loss

import znnl as rnd
import znnl as nl
from znnl.analysis import JAXNTKComputation


class TestRecorderDeployment:
@@ -56,36 +58,58 @@ def setup_class(cls):
Prepare the class for running.
"""
# Data Generator
cls.data_generator = rnd.data.MNISTGenerator(ds_size=10)
cls.data_generator = nl.data.MNISTGenerator(ds_size=10)

# Make a network
network = stax.serial(
stax.Flatten(), stax.Dense(128), stax.Relu(), stax.Dense(10)
)

# Set the class assigned recorders
cls.train_recorder = rnd.training_recording.JaxRecorder(
loss=True, accuracy=True, update_rate=1, chunk_size=11, name="trainer"
cls.train_recorder = nl.training_recording.JaxRecorder(
storage_path=".",
name="trainer",
update_rate=1,
chunk_size=11,
measurements=[
Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
],
)
cls.test_recorder = rnd.training_recording.JaxRecorder(
loss=True, accuracy=True, ntk=True, update_rate=5
cls.test_recorder = nl.training_recording.JaxRecorder(
storage_path=".",
name="tester",
update_rate=5,
measurements=[
Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
NTK(),
],
)

# Define the model
cls.production_model = rnd.models.NTModel(
cls.production_model = nl.models.NTModel(
nt_module=network,
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 28, 28, 1),
)

cls.train_recorder.instantiate_recorder(data_set=cls.data_generator.train_ds)
cls.test_recorder.instantiate_recorder(data_set=cls.data_generator.test_ds)
cls.train_recorder.instantiate_recorder(
data_set=cls.data_generator.train_ds,
model=cls.production_model,
ntk_computation=JAXNTKComputation(cls.production_model.ntk_apply_fn),
)
cls.test_recorder.instantiate_recorder(
data_set=cls.data_generator.test_ds,
model=cls.production_model,
ntk_computation=JAXNTKComputation(cls.production_model.ntk_apply_fn),
)

# Define training strategy
cls.training_strategy = rnd.training_strategies.SimpleTraining(
cls.training_strategy = nl.training_strategies.SimpleTraining(
model=cls.production_model,
loss_fn=rnd.loss_functions.CrossEntropyLoss(),
accuracy_fn=rnd.accuracy_functions.LabelAccuracy(),
loss_fn=nl.loss_functions.CrossEntropyLoss(),
accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
recorders=[cls.train_recorder, cls.test_recorder],
)
# Train the model with the recorders
@@ -109,33 +133,44 @@ def test_private_arrays(self):
"""
Test that the recorder internally holds the correct values.
"""
assert len(self.train_recorder._loss_array) == 10
assert onp.sum(self.train_recorder._loss_array) > 0
assert len(self.train_recorder._accuracy_array) == 10
assert onp.sum(self.train_recorder._accuracy_array) > 0
assert len(self.train_recorder._results["loss"]) == 10
assert onp.sum(self.train_recorder._results["loss"]) > 0
assert len(self.train_recorder._results["accuracy"]) == 10
assert onp.sum(self.train_recorder._results["accuracy"]) > 0

assert len(self.test_recorder._loss_array) == 2
assert len(self.test_recorder._accuracy_array) == 2
assert onp.sum(self.test_recorder._loss_array) > 0
assert onp.sum(self.test_recorder._accuracy_array) > 0
assert len(self.test_recorder._results["loss"]) == 2
assert len(self.test_recorder._results["accuracy"]) == 2
assert onp.sum(self.test_recorder._results["loss"]) > 0
assert onp.sum(self.test_recorder._results["accuracy"]) > 0

def test_data_dump(self):
"""
Test that the data dumping works correctly.
"""
with tempfile.TemporaryDirectory() as directory:

new_model = copy.deepcopy(self.production_model)
train_recorder = copy.deepcopy(self.train_recorder)
train_recorder.storage_path = directory

train_recorder = nl.training_recording.JaxRecorder(
storage_path=directory,
name="trainer",
update_rate=1,
chunk_size=11,
measurements=[
Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
],
)
train_recorder.instantiate_recorder(
train_recorder._data_set, overwrite=True
data_set=self.data_generator.train_ds,
model=new_model,
)

# Define the training strategy
training_strategy = rnd.training_strategies.SimpleTraining(
training_strategy = nl.training_strategies.SimpleTraining(
model=new_model,
loss_fn=rnd.loss_functions.CrossEntropyLoss(),
accuracy_fn=rnd.accuracy_functions.LabelAccuracy(),
loss_fn=nl.loss_functions.CrossEntropyLoss(),
accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
recorders=[train_recorder],
)

@@ -147,25 +182,28 @@ def test_data_dump(self):
epochs=20,
)

# Print all files in the directory
print(f"Files in directory: {os.listdir(directory)}")

# Check if there is data in database
with hf.File(f"{directory}/trainer.h5", "r") as db:
db_loss = onp.array(db["loss"])
db_accuracy = onp.array(db["accuracy"])

class_loss = onp.array(train_recorder._loss_array)
class_accuracy = onp.array(train_recorder._accuracy_array)
class_loss = onp.array(train_recorder._results["loss"])
class_accuracy = onp.array(train_recorder._results["accuracy"])

assert db_loss.shape == (11,)
assert class_loss.shape == (9,)
assert db_loss.shape == (11, 1)
assert class_loss.shape == (9, 1)
testing.assert_raises(
AssertionError,
testing.assert_array_equal,
db_loss.sum(),
class_loss.sum(),
)

assert db_accuracy.shape == (11,)
assert class_accuracy.shape == (9,)
assert db_accuracy.shape == (11, 1)
assert class_accuracy.shape == (9, 1)
testing.assert_raises(
AssertionError,
testing.assert_array_equal,
@@ -177,36 +215,47 @@ def test_export_function_no_db(self):
"""
Test that the reports are exported correctly.
"""
train_report = self.train_recorder.gather_recording()
test_report = self.test_recorder.gather_recording()
train_report = self.train_recorder.gather()
test_report = self.test_recorder.gather()

assert len(train_report.loss) == 10
assert onp.sum(train_report.loss) > 0
assert len(train_report.accuracy) == 10
assert onp.sum(train_report.accuracy) > 0
assert len(train_report["loss"]) == 10
assert onp.sum(train_report["loss"]) > 0
assert len(train_report["accuracy"]) == 10
assert onp.sum(train_report["accuracy"]) > 0

# Arrays should be resized now.
assert len(test_report.loss) == 2
assert onp.sum(test_report.loss) > 0
assert len(test_report.accuracy) == 2
assert onp.sum(test_report.accuracy) > 0
assert len(test_report["loss"]) == 2
assert onp.sum(test_report["loss"]) > 0
assert len(test_report["accuracy"]) == 2
assert onp.sum(test_report["accuracy"]) > 0

def test_export_function_db(self):
"""
Test that the reports are exported correctly.
"""
with tempfile.TemporaryDirectory() as directory:
new_model = copy.deepcopy(self.production_model)
train_recorder = copy.deepcopy(self.train_recorder)
train_recorder.storage_path = directory

train_recorder = nl.training_recording.JaxRecorder(
storage_path=directory,
name="trainer",
update_rate=1,
chunk_size=11,
measurements=[
Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
],
)
train_recorder.instantiate_recorder(
train_recorder._data_set, overwrite=True
data_set=self.data_generator.train_ds,
model=new_model,
)

# Define the training strategy
training_strategy = rnd.training_strategies.SimpleTraining(
training_strategy = nl.training_strategies.SimpleTraining(
model=new_model,
loss_fn=rnd.loss_functions.CrossEntropyLoss(),
accuracy_fn=rnd.accuracy_functions.LabelAccuracy(),
loss_fn=nl.loss_functions.CrossEntropyLoss(),
accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
recorders=[train_recorder],
)

@@ -218,19 +267,8 @@ def test_export_function_db(self):
epochs=20,
)

report = train_recorder.gather_recording()
assert report.loss.shape[0] == 20
testing.assert_array_equal(report.loss[11:], train_recorder._loss_array)

def test_export_function_no_db_custom_selection(self):
"""
Test that the reports are exported correctly.
"""
# Note, NTK is not recorded, it should be caught and removed.
train_report = self.train_recorder.gather_recording(
selected_properties=["loss", "ntk"]
)

assert len(train_report.loss) == 10
assert onp.sum(train_report.loss) > 0
assert "ntk" not in list(train_report.__dict__)
report = train_recorder.gather()
assert report["loss"].shape[0] == 20
testing.assert_array_equal(
report["loss"][11:], train_recorder._results["loss"]
)
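
As a readback note grounded in the tests above: recorder.gather() returns the in-memory results as a dictionary of arrays keyed by measurement name, while results that have already been dumped to disk (once chunk_size is exceeded) live in an HDF5 file named after the recorder inside its storage_path. A minimal sketch, assuming h5py imported as hf, numpy as onp, and an already-trained recorder configured with name="trainer" and storage_path=".":

import h5py as hf
import numpy as onp

# In-memory results still held by the recorder.
report = recorder.gather()
loss_history = onp.array(report["loss"])  # one row per recorded epoch

# Results already dumped to disk; in the tests above these arrays carry a
# trailing measurement dimension, e.g. shape (11, 1).
with hf.File("./trainer.h5", "r") as db:
    db_loss = onp.array(db["loss"])
    db_accuracy = onp.array(db["accuracy"])
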