From 30c1e705c1ca2b07f32682a34431523630582905 Mon Sep 17 00:00:00 2001
From: Konstantin Nikolaou <87869540+KonstiNik@users.noreply.github.com>
Date: Mon, 10 Jun 2024 15:06:09 +0200
Subject: [PATCH] Konsti papyrus recording (#121)

* Write independent module for computing the NTK
* Remove NTK computation from the model
  Adapt all tests
* Run black and isort
* remove ntk from huggingface flax model
* Adapt integration tests to new ntk computation
* Adapt notebook API to new ntk computation
* fix ntk computation in hf models
* remove unnecessary imports
* Install papyrus via requirements.txt
* Create Papyrus Jax Recorder and tests
* make NTK computation return the ntk in a list
* Create tests for the jax ntk computation
* remove unused imports from model tests
* adapt simple training to new recorder
* adapt jax recorder to API changes of papyrus
* Adapt Examples to new API
  - Computing CVs Example
  - Contrastive Loss Example
  - Using data recorder Example
* run black
* Adapt training strategies to new recorder
  Adapt examples:
  - Using Training Strategies
  - ResNet Example
* Adapt tests to new recorders
* Adapt the trace optimizer to the more flexible NTK calculation.
  This allows for subsampling the NTK.
* Add explanation to the data recorders notebook
* include access token in papyrus installation
* change access token reference
* try granting access to token via environment
* exclude token from install command
* try setting git config including access token
* include access token also in doc building
* Write NTK subsampling class.
* Make the NTK calculation use the entire data set with inputs and targets.
  Add option to set the data set keys in the init of the ntk computation
* run black
* Write class-wise ntk computation
* make loss derivative computation work and include in example
* Write ntk combinations which can be used to compute the mutual information of a system.
* Move ntk computation to analysis
* Fix bugs in jax combinations
* Complete docstrings of papyrus jax recorder
* Write example to compute the neural mutual information
* remove outdated recorder
* Adapt integration test to new recorder
* Include changes from papyrus going public.
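For reviewers: below is a minimal end-to-end sketch of the new recording API, stitched together from the integration tests in this patch. The layer sizes, `ds_size`, and epoch count are illustrative placeholders, and the `train_model` call assumes znnl's usual `train_ds`/`test_ds`/`epochs` signature; the class names and constructor arguments are taken from the tests.

```python
import optax
import znnl as nl
from neural_tangents import stax
from papyrus.measurements import NTK, Accuracy, Loss

from znnl.analysis import JAXNTKComputation

# Small data set and dense network (placeholder architecture).
data_generator = nl.data.MNISTGenerator(ds_size=10)
network = stax.serial(stax.Flatten(), stax.Dense(10))

model = nl.models.NTModel(
    nt_module=network,
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 28, 28, 1),
)

# The NTK now comes from a standalone computation module, not the model.
ntk_computation = JAXNTKComputation(model.ntk_apply_fn)

# Recorders are configured with papyrus measurement objects instead of
# boolean flags, and are bound to a model and an NTK computation.
train_recorder = nl.training_recording.JaxRecorder(
    storage_path=".",
    name="trainer",
    update_rate=1,
    measurements=[
        Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)),
        Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),
        NTK(),
    ],
)
train_recorder.instantiate_recorder(
    data_set=data_generator.train_ds,
    model=model,
    ntk_computation=ntk_computation,
)

training_strategy = nl.training_strategies.SimpleTraining(
    model=model,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
    recorders=[train_recorder],
)
training_strategy.train_model(
    train_ds=data_generator.train_ds,
    test_ds=data_generator.test_ds,
    epochs=10,
)

# Recordings are gathered as a dict keyed by measurement name.
report = train_recorder.gather()
```

The subsampled, class-wise, and combination NTKs introduced here share the same `compute_ntk(params, dataset)` interface, e.g. `JAXNTKSubsampling(model.ntk_apply_fn, ntk_size=3, seed=0)`, `JAXNTKClassWise(model.ntk_apply_fn)`, or `JAXNTKCombinations(model.ntk_apply_fn, class_labels=[0, 1])`, so any of them can be passed to `instantiate_recorder` or used directly, as in the new neural mutual information example.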
--------- Co-authored-by: knikolaou <> --- .github/workflows/doc.yml | 3 + .github/workflows/pytest.yml | 1 + .../test_huggingface_flax_model_deployment.py | 2 - .../test_training_recording_deployment.py | 166 +++-- CI/unit_tests/analysis/test_jax_ntk.py | 180 +++++ .../analysis/test_jax_ntk_classwise.py | 194 +++++ .../analysis/test_jax_ntk_combinations.py | 258 +++++++ .../analysis/test_jax_ntk_subsampling.py | 133 ++++ .../models/_test_huggingface_flax_model.py | 19 +- CI/unit_tests/models/test_flax_model.py | 26 +- CI/unit_tests/models/test_nt_model.py | 19 +- CI/unit_tests/models/test_seed.py | 2 - .../optimizers/test_trace_optimizer.py | 14 +- .../training_recording/test_data_storage.py | 222 +++--- .../training_recording/test_jax_recorder.py | 195 +++++ .../test_training_recording.py | 342 ++++----- CI/unit_tests/utils/test_matrix_utils.py | 2 +- examples/CIFAR10.ipynb | 2 +- examples/Computing-Collective-Variables.ipynb | 98 ++- examples/Contrastive-Loss.ipynb | 177 +++-- examples/Neural-Mutual-Information.ipynb | 404 ++++++++++ examples/ResNet-Example.ipynb | 59 +- examples/Using-Training-Strategies.ipynb | 90 ++- examples/Using-the-Data-Recorders.ipynb | 66 +- requirements.txt | 3 +- znnl/agents/approximate_maximum_entropy.py | 10 +- znnl/analysis/__init__.py | 8 + znnl/analysis/jax_ntk.py | 172 +++++ znnl/analysis/jax_ntk_classwise.py | 242 ++++++ znnl/analysis/jax_ntk_combinations.py | 315 ++++++++ znnl/analysis/jax_ntk_subsampling.py | 229 ++++++ znnl/models/flax_model.py | 33 +- znnl/models/huggingface_flax_model.py | 37 +- znnl/models/jax_model.py | 88 +-- znnl/models/nt_model.py | 35 +- znnl/optimizers/trace_optimizer.py | 3 +- znnl/training_recording/__init__.py | 7 +- znnl/training_recording/jax_recording.py | 688 ------------------ .../papyrus_jax_recording.py | 255 +++++++ .../loss_aware_reservoir.py | 10 +- .../partitioned_training.py | 10 +- znnl/training_strategies/simple_training.py | 17 +- 42 files changed, 3378 insertions(+), 1458 deletions(-) create mode 100644 CI/unit_tests/analysis/test_jax_ntk.py create mode 100644 CI/unit_tests/analysis/test_jax_ntk_classwise.py create mode 100644 CI/unit_tests/analysis/test_jax_ntk_combinations.py create mode 100644 CI/unit_tests/analysis/test_jax_ntk_subsampling.py create mode 100644 CI/unit_tests/training_recording/test_jax_recorder.py create mode 100644 examples/Neural-Mutual-Information.ipynb create mode 100644 znnl/analysis/jax_ntk.py create mode 100644 znnl/analysis/jax_ntk_classwise.py create mode 100644 znnl/analysis/jax_ntk_combinations.py create mode 100644 znnl/analysis/jax_ntk_subsampling.py delete mode 100644 znnl/training_recording/jax_recording.py create mode 100644 znnl/training_recording/papyrus_jax_recording.py diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 5d84512..9355eb8 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -13,8 +13,11 @@ jobs: with: python-version: '3.11' - name: Install dependencies + env: + super_secret: ${{ secrets.PAPYRUS_ACCESS }} run: | sudo apt install pandoc + git config --global url."https://${{ secrets.PAPYRUS_ACCESS }}@github".insteadOf https://github pip install -r dev-requirements.txt pip install -r requirements.txt pip install . 
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index cc6aff5..2575604 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,6 +21,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + git config --global url."https://${{ secrets.PAPYRUS_ACCESS }}@github".insteadOf https://github python -m pip install --upgrade pip pip install -r dev-requirements.txt pip install -r requirements.txt diff --git a/CI/integration_tests/models/test_huggingface_flax_model_deployment.py b/CI/integration_tests/models/test_huggingface_flax_model_deployment.py index 11ad62f..0bd68a4 100644 --- a/CI/integration_tests/models/test_huggingface_flax_model_deployment.py +++ b/CI/integration_tests/models/test_huggingface_flax_model_deployment.py @@ -95,12 +95,10 @@ def setup_class(cls): cls.resnet18 = HuggingFaceFlaxModel( resnet18, optax.sgd(learning_rate=1e-3), - batch_size=3, ) cls.resnet50 = HuggingFaceFlaxModel( resnet50, optax.sgd(learning_rate=1e-3), - batch_size=3, ) key = random.PRNGKey(0) diff --git a/CI/integration_tests/training_recording/test_training_recording_deployment.py b/CI/integration_tests/training_recording/test_training_recording_deployment.py index 5084432..c8699c0 100644 --- a/CI/integration_tests/training_recording/test_training_recording_deployment.py +++ b/CI/integration_tests/training_recording/test_training_recording_deployment.py @@ -37,8 +37,10 @@ import optax from neural_tangents import stax from numpy import testing +from papyrus.measurements import NTK, Accuracy, Loss -import znnl as rnd +import znnl as nl +from znnl.analysis import JAXNTKComputation class TestRecorderDeployment: @@ -56,7 +58,7 @@ def setup_class(cls): Prepare the class for running. """ # Data Generator - cls.data_generator = rnd.data.MNISTGenerator(ds_size=10) + cls.data_generator = nl.data.MNISTGenerator(ds_size=10) # Make a network network = stax.serial( @@ -64,28 +66,50 @@ def setup_class(cls): ) # Set the class assigned recorders - cls.train_recorder = rnd.training_recording.JaxRecorder( - loss=True, accuracy=True, update_rate=1, chunk_size=11, name="trainer" + cls.train_recorder = nl.training_recording.JaxRecorder( + storage_path=".", + name="trainer", + update_rate=1, + chunk_size=11, + measurements=[ + Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), + Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()), + ], ) - cls.test_recorder = rnd.training_recording.JaxRecorder( - loss=True, accuracy=True, ntk=True, update_rate=5 + cls.test_recorder = nl.training_recording.JaxRecorder( + storage_path=".", + name="tester", + update_rate=5, + measurements=[ + Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), + Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()), + NTK(), + ], ) # Define the model - cls.production_model = rnd.models.NTModel( + cls.production_model = nl.models.NTModel( nt_module=network, optimizer=optax.adam(learning_rate=0.01), input_shape=(1, 28, 28, 1), ) - cls.train_recorder.instantiate_recorder(data_set=cls.data_generator.train_ds) - cls.test_recorder.instantiate_recorder(data_set=cls.data_generator.test_ds) + cls.train_recorder.instantiate_recorder( + data_set=cls.data_generator.train_ds, + model=cls.production_model, + ntk_computation=JAXNTKComputation(cls.production_model.ntk_apply_fn), + ) + cls.test_recorder.instantiate_recorder( + data_set=cls.data_generator.test_ds, + model=cls.production_model, + ntk_computation=JAXNTKComputation(cls.production_model.ntk_apply_fn), + ) # Define 
training strategy - cls.training_strategy = rnd.training_strategies.SimpleTraining( + cls.training_strategy = nl.training_strategies.SimpleTraining( model=cls.production_model, - loss_fn=rnd.loss_functions.CrossEntropyLoss(), - accuracy_fn=rnd.accuracy_functions.LabelAccuracy(), + loss_fn=nl.loss_functions.CrossEntropyLoss(), + accuracy_fn=nl.accuracy_functions.LabelAccuracy(), recorders=[cls.train_recorder, cls.test_recorder], ) # Train the model with the recorders @@ -109,33 +133,44 @@ def test_private_arrays(self): """ Test that the recorder internally holds the correct values. """ - assert len(self.train_recorder._loss_array) == 10 - assert onp.sum(self.train_recorder._loss_array) > 0 - assert len(self.train_recorder._accuracy_array) == 10 - assert onp.sum(self.train_recorder._accuracy_array) > 0 + assert len(self.train_recorder._results["loss"]) == 10 + assert onp.sum(self.train_recorder._results["loss"]) > 0 + assert len(self.train_recorder._results["accuracy"]) == 10 + assert onp.sum(self.train_recorder._results["accuracy"]) > 0 - assert len(self.test_recorder._loss_array) == 2 - assert len(self.test_recorder._accuracy_array) == 2 - assert onp.sum(self.test_recorder._loss_array) > 0 - assert onp.sum(self.test_recorder._accuracy_array) > 0 + assert len(self.test_recorder._results["loss"]) == 2 + assert len(self.test_recorder._results["accuracy"]) == 2 + assert onp.sum(self.test_recorder._results["loss"]) > 0 + assert onp.sum(self.test_recorder._results["accuracy"]) > 0 def test_data_dump(self): """ Test that the data dumping works correctly. """ with tempfile.TemporaryDirectory() as directory: + new_model = copy.deepcopy(self.production_model) - train_recorder = copy.deepcopy(self.train_recorder) - train_recorder.storage_path = directory + + train_recorder = nl.training_recording.JaxRecorder( + storage_path=directory, + name="trainer", + update_rate=1, + chunk_size=11, + measurements=[ + Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), + Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()), + ], + ) train_recorder.instantiate_recorder( - train_recorder._data_set, overwrite=True + data_set=self.data_generator.train_ds, + model=new_model, ) # Define the training strategy - training_strategy = rnd.training_strategies.SimpleTraining( + training_strategy = nl.training_strategies.SimpleTraining( model=new_model, - loss_fn=rnd.loss_functions.CrossEntropyLoss(), - accuracy_fn=rnd.accuracy_functions.LabelAccuracy(), + loss_fn=nl.loss_functions.CrossEntropyLoss(), + accuracy_fn=nl.accuracy_functions.LabelAccuracy(), recorders=[train_recorder], ) @@ -147,16 +182,19 @@ def test_data_dump(self): epochs=20, ) + # Print all files in the directory + print(f"Files in directory: {os.listdir(directory)}") + # Check if there is data in database with hf.File(f"{directory}/trainer.h5", "r") as db: db_loss = onp.array(db["loss"]) db_accuracy = onp.array(db["accuracy"]) - class_loss = onp.array(train_recorder._loss_array) - class_accuracy = onp.array(train_recorder._accuracy_array) + class_loss = onp.array(train_recorder._results["loss"]) + class_accuracy = onp.array(train_recorder._results["accuracy"]) - assert db_loss.shape == (11,) - assert class_loss.shape == (9,) + assert db_loss.shape == (11, 1) + assert class_loss.shape == (9, 1) testing.assert_raises( AssertionError, testing.assert_array_equal, @@ -164,8 +202,8 @@ def test_data_dump(self): class_loss.sum(), ) - assert db_accuracy.shape == (11,) - assert class_accuracy.shape == (9,) + assert db_accuracy.shape == (11, 1) + assert 
class_accuracy.shape == (9, 1) testing.assert_raises( AssertionError, testing.assert_array_equal, @@ -177,19 +215,19 @@ def test_export_function_no_db(self): """ Test that the reports are exported correctly. """ - train_report = self.train_recorder.gather_recording() - test_report = self.test_recorder.gather_recording() + train_report = self.train_recorder.gather() + test_report = self.test_recorder.gather() - assert len(train_report.loss) == 10 - assert onp.sum(train_report.loss) > 0 - assert len(train_report.accuracy) == 10 - assert onp.sum(train_report.accuracy) > 0 + assert len(train_report["loss"]) == 10 + assert onp.sum(train_report["loss"]) > 0 + assert len(train_report["accuracy"]) == 10 + assert onp.sum(train_report["accuracy"]) > 0 # Arrays should be resized now. - assert len(test_report.loss) == 2 - assert onp.sum(test_report.loss) > 0 - assert len(test_report.accuracy) == 2 - assert onp.sum(test_report.accuracy) > 0 + assert len(test_report["loss"]) == 2 + assert onp.sum(test_report["loss"]) > 0 + assert len(test_report["accuracy"]) == 2 + assert onp.sum(test_report["accuracy"]) > 0 def test_export_function_db(self): """ @@ -197,16 +235,27 @@ def test_export_function_db(self): """ with tempfile.TemporaryDirectory() as directory: new_model = copy.deepcopy(self.production_model) - train_recorder = copy.deepcopy(self.train_recorder) - train_recorder.storage_path = directory + + train_recorder = nl.training_recording.JaxRecorder( + storage_path=directory, + name="trainer", + update_rate=1, + chunk_size=11, + measurements=[ + Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), + Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()), + ], + ) train_recorder.instantiate_recorder( - train_recorder._data_set, overwrite=True + data_set=self.data_generator.train_ds, + model=new_model, ) + # Define the training strategy - training_strategy = rnd.training_strategies.SimpleTraining( + training_strategy = nl.training_strategies.SimpleTraining( model=new_model, - loss_fn=rnd.loss_functions.CrossEntropyLoss(), - accuracy_fn=rnd.accuracy_functions.LabelAccuracy(), + loss_fn=nl.loss_functions.CrossEntropyLoss(), + accuracy_fn=nl.accuracy_functions.LabelAccuracy(), recorders=[train_recorder], ) @@ -218,19 +267,8 @@ def test_export_function_db(self): epochs=20, ) - report = train_recorder.gather_recording() - assert report.loss.shape[0] == 20 - testing.assert_array_equal(report.loss[11:], train_recorder._loss_array) - - def test_export_function_no_db_custom_selection(self): - """ - Test that the reports are exported correctly. - """ - # Note, NTK is not recorded, it should be caught and removed. - train_report = self.train_recorder.gather_recording( - selected_properties=["loss", "ntk"] - ) - - assert len(train_report.loss) == 10 - assert onp.sum(train_report.loss) > 0 - assert "ntk" not in list(train_report.__dict__) + report = train_recorder.gather() + assert report["loss"].shape[0] == 20 + testing.assert_array_equal( + report["loss"][11:], train_recorder._results["loss"] + ) diff --git a/CI/unit_tests/analysis/test_jax_ntk.py b/CI/unit_tests/analysis/test_jax_ntk.py new file mode 100644 index 0000000..1836bf0 --- /dev/null +++ b/CI/unit_tests/analysis/test_jax_ntk.py @@ -0,0 +1,180 @@ +""" +ZnNL: A Zincwarecode package. 
+ +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import jax.numpy as np +import neural_tangents as nt +import optax +from flax import linen as nn +from jax import random + +from znnl.analysis import JAXNTKComputation +from znnl.models import FlaxModel + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=2, use_bias=True)(x) + return x + + +class TestJAXNTKComputation: + """ + Test class for the JAX NTK computation class. + """ + + @classmethod + def setup_class(cls): + """ + Setup the test class. + """ + cls.flax_model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(8,), + seed=17, + ) + + cls.dataset = { + "inputs": random.normal(random.PRNGKey(0), (10, 8)), + "targets": random.normal(random.PRNGKey(1), (10, 2)), + } + + def test_constructor(self): + """ + Test the constructor of the JAX NTK computation class. + """ + apply_fn = lambda x: x + batch_size = 10 + ntk_implementation = None + trace_axes = (-1,) + store_on_device = False + flatten = True + data_keys = ["image", "label"] + + jax_ntk_computation = JAXNTKComputation( + apply_fn=apply_fn, + batch_size=batch_size, + ntk_implementation=ntk_implementation, + trace_axes=trace_axes, + store_on_device=store_on_device, + flatten=flatten, + data_keys=data_keys, + ) + + assert jax_ntk_computation.apply_fn == apply_fn + assert jax_ntk_computation.batch_size == batch_size + assert jax_ntk_computation.trace_axes == trace_axes + assert jax_ntk_computation.store_on_device == store_on_device + assert jax_ntk_computation.flatten == flatten + assert jax_ntk_computation.data_keys == data_keys + + def test_constructor_default(self): + """ + Test the default setting of the constructor of the JAX NTK computation class. + """ + apply_fn = lambda x: x + + jax_ntk_computation = JAXNTKComputation( + apply_fn=apply_fn, + ) + + assert jax_ntk_computation.apply_fn == apply_fn + assert jax_ntk_computation.batch_size == 10 + assert jax_ntk_computation.trace_axes == () + assert jax_ntk_computation.store_on_device == False + assert jax_ntk_computation.flatten == True + assert jax_ntk_computation.data_keys == ["inputs", "targets"] + + # Default ntk_implementation should be NTK_VECTOR_PRODUCTS + assert ( + jax_ntk_computation.ntk_implementation + == nt.NtkImplementation.NTK_VECTOR_PRODUCTS + ) + + def test_check_shape(self): + """ + Test the shape checking function. + """ + jax_ntk_computation = JAXNTKComputation(apply_fn=self.flax_model.ntk_apply_fn) + + ntk = np.ones((10, 10, 3, 3)) + ntk_ = jax_ntk_computation._check_shape(ntk) + + assert jax_ntk_computation._is_flattened == True + assert ntk_.shape == (30, 30) + + def test_compute_ntk(self): + """ + Test the computation of the NTK. 
+ """ + params = {"params": self.flax_model.model_state.params} + + # Trace axes is empty and flatten is True + jax_ntk_computation = JAXNTKComputation( + apply_fn=self.flax_model.ntk_apply_fn, + trace_axes=(), + flatten=True, + ) + ntk = jax_ntk_computation.compute_ntk(params, self.dataset) + assert np.shape(ntk) == (1, 20, 20) + + # Trace axes is empty and flatten is False + jax_ntk_computation = JAXNTKComputation( + apply_fn=self.flax_model.ntk_apply_fn, + trace_axes=(), + flatten=False, + ) + ntk = jax_ntk_computation.compute_ntk(params, self.dataset) + + assert np.shape(ntk) == (1, 10, 10, 2, 2) + + # Trace axes is (-1,) and flatten is True + jax_ntk_computation = JAXNTKComputation( + apply_fn=self.flax_model.ntk_apply_fn, + trace_axes=(-1,), + flatten=True, + ) + ntk = jax_ntk_computation.compute_ntk(params, self.dataset) + + assert np.shape(ntk) == (1, 10, 10) + + # Trace axes is (-1,) and flatten is False + jax_ntk_computation = JAXNTKComputation( + apply_fn=self.flax_model.ntk_apply_fn, + trace_axes=(-1,), + flatten=False, + ) + ntk = jax_ntk_computation.compute_ntk(params, self.dataset) + + assert np.shape(ntk) == (1, 10, 10) diff --git a/CI/unit_tests/analysis/test_jax_ntk_classwise.py b/CI/unit_tests/analysis/test_jax_ntk_classwise.py new file mode 100644 index 0000000..f25f0c2 --- /dev/null +++ b/CI/unit_tests/analysis/test_jax_ntk_classwise.py @@ -0,0 +1,194 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import jax.numpy as np +import optax +from flax import linen as nn +from jax import random + +from znnl.analysis import JAXNTKClassWise +from znnl.models import FlaxModel + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=2, use_bias=True)(x) + return x + + +class TestJAXNTKClassWise: + """ + Test class for the class-wise JAX NTK computation. + """ + + @classmethod + def setup_class(cls): + """ + Setup the test class. + """ + cls.flax_model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(8,), + seed=17, + ) + + # Create random labels between zero and two + targets = np.array([0, 1, 2, 0, 1, 2, 0, 0]) + one_hot_targets = np.eye(3)[targets] + + cls.dataset_int = { + "inputs": random.normal(random.PRNGKey(0), (8, 8)), + "targets": np.expand_dims(targets, axis=1), + } + cls.dataset_onehot = { + "inputs": random.normal(random.PRNGKey(0), (8, 8)), + "targets": one_hot_targets, + } + + def test_constructor(self): + """ + Test the constructor of the JAX NTK computation class. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + assert jax_ntk.batch_size == 10 + assert jax_ntk._sample_indices == None + + def test_get_label_indices(self): + """ + Test the _get_label_indices method. 
+ """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + # Test the one-hot targets + sample_idx_one_hot = jax_ntk._get_label_indices(self.dataset_onehot) + assert len(sample_idx_one_hot) == 3 + assert len(sample_idx_one_hot[0]) == 4 + assert len(sample_idx_one_hot[1]) == 2 + assert len(sample_idx_one_hot[2]) == 2 + + # Test the integer targets + sample_idx_int = jax_ntk._get_label_indices(self.dataset_int) + assert len(sample_idx_int) == 3 + assert len(sample_idx_int[0]) == 4 + assert len(sample_idx_int[1]) == 2 + assert len(sample_idx_int[2]) == 2 + + # Test upper bound of ntk_size + jax_ntk.ntk_size = 3 + sample_idx_one_hot = jax_ntk._get_label_indices(self.dataset_onehot) + assert len(sample_idx_one_hot) == 3 + assert len(sample_idx_one_hot[0]) == 3 + assert len(sample_idx_one_hot[1]) == 2 + assert len(sample_idx_one_hot[2]) == 2 + + def test_subsample_data(self): + """ + Test the _subsample_data method. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + # Test the one-hot targets + subsampled_data_one_hot = jax_ntk._subsample_data( + self.dataset_onehot["inputs"], + jax_ntk._get_label_indices(self.dataset_onehot), + ) + assert len(subsampled_data_one_hot) == 3 + assert subsampled_data_one_hot[0].shape == (4, 8) + assert subsampled_data_one_hot[1].shape == (2, 8) + assert subsampled_data_one_hot[2].shape == (2, 8) + + # Test the integer targets + subsampled_data_int = jax_ntk._subsample_data( + self.dataset_int["inputs"], jax_ntk._get_label_indices(self.dataset_int) + ) + assert len(subsampled_data_int) == 3 + assert subsampled_data_int[0].shape == (4, 8) + assert subsampled_data_int[1].shape == (2, 8) + assert subsampled_data_int[2].shape == (2, 8) + + def test_compute_ntk(self): + """ + Test the compute_ntk method. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.ntk_apply_fn, + batch_size=10, + ) + + params = {"params": self.flax_model.model_state.params} + + # Test the one-hot targets + ntks = jax_ntk.compute_ntk(params, self.dataset_onehot) + assert len(ntks) == 3 + assert ntks[0].shape == (8, 8) + assert ntks[1].shape == (4, 4) + assert ntks[2].shape == (4, 4) + + # Test the integer targets + ntks = jax_ntk.compute_ntk(params, self.dataset_int) + print(ntks) + assert len(ntks) == 3 + assert ntks[0].shape == (8, 8) + assert ntks[1].shape == (4, 4) + assert ntks[2].shape == (4, 4) + + # Test if not all classes are present + dataset = { + "inputs": self.dataset_int["inputs"], + "targets": np.array([0, 0, 0, 0, 0, 0, 0, 0]), + } + ntks = jax_ntk.compute_ntk(params, dataset) + assert len(ntks) == 1 + assert ntks[0].shape == (16, 16) + + dataset = { + "inputs": self.dataset_int["inputs"], + "targets": np.array([0, 0, 0, 0, 0, 0, 0, 5]), + } + ntks = jax_ntk.compute_ntk(params, dataset) + assert len(ntks) == 6 + assert ntks[0].shape == (14, 14) + assert ntks[1].shape == (0, 0) + assert ntks[2].shape == (0, 0) + assert ntks[3].shape == (0, 0) + assert ntks[4].shape == (0, 0) + assert ntks[5].shape == (2, 2) diff --git a/CI/unit_tests/analysis/test_jax_ntk_combinations.py b/CI/unit_tests/analysis/test_jax_ntk_combinations.py new file mode 100644 index 0000000..3258526 --- /dev/null +++ b/CI/unit_tests/analysis/test_jax_ntk_combinations.py @@ -0,0 +1,258 @@ +""" +ZnNL: A Zincwarecode package. 
+ +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import jax.numpy as np +import neural_tangents as nt +import optax +from flax import linen as nn +from jax import random +from papyrus.utils.matrix_utils import flatten_rank_4_tensor + +from znnl.analysis import JAXNTKCombinations +from znnl.models import FlaxModel + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=2, use_bias=True)(x) + return x + + +class TestJAXNTKCombinations: + """ + Test class for the JAX NTK combinations computation. + """ + + @classmethod + def setup_class(cls): + """ + Setup the test class. + """ + cls.flax_model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(8,), + seed=17, + ) + + # Create random labels between zero and two + targets = np.array([0, 1, 2, 0, 1, 2, 0, 0]) + one_hot_targets = np.eye(3)[targets] + + cls.dataset_int = { + "inputs": random.normal(random.PRNGKey(0), (8, 8)), + "targets": np.expand_dims(targets, axis=1), + } + cls.dataset_onehot = { + "inputs": random.normal(random.PRNGKey(0), (8, 8)), + "targets": one_hot_targets, + } + + def test_constructor(self): + """ + Test the constructor of the JAX NTK computation class. + """ + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1, 2], + ) + + assert jax_ntk.class_labels == [0, 1, 2] + + def test_reduce_data_to_labels(self): + """ + Test the _reduce_data_to_labels method. + """ + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1], + ) + + # Test the one-hot targets + reduced_data = jax_ntk._reduce_data_to_labels(self.dataset_onehot) + assert reduced_data["inputs"].shape == (6, 8) + assert reduced_data["targets"].shape == (6, 3) + + # Test the integer targets + reduced_data = jax_ntk._reduce_data_to_labels(self.dataset_int) + assert reduced_data["inputs"].shape == (6, 8) + assert reduced_data["targets"].shape == (6, 1) + + def test_get_label_indices(self): + """ + Test the _get_label_indices method.
+ """ + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1, 2], + ) + + # Test the one-hot targets + sample_idx_one_hot = jax_ntk._get_label_indices(self.dataset_onehot) + assert type(sample_idx_one_hot) == dict + assert len(sample_idx_one_hot) == 3 + assert len(sample_idx_one_hot[0]) == 4 + assert len(sample_idx_one_hot[1]) == 2 + assert len(sample_idx_one_hot[2]) == 2 + + # Test the integer targets + sample_idx_int = jax_ntk._get_label_indices(self.dataset_int) + assert type(sample_idx_int) == dict + assert len(sample_idx_int) == 3 + assert len(sample_idx_int[0]) == 4 + assert len(sample_idx_int[1]) == 2 + assert len(sample_idx_int[2]) == 2 + + def test_compute_combinations(self): + """ + Test the _compute_combinations method. + """ + + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1], + ) + combinations = jax_ntk._compute_combinations() + assert combinations == [(0,), (1,), (0, 1)] + + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1, 2], + ) + combinations = jax_ntk._compute_combinations() + assert combinations == [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)] + + def test_take_sub_ntk(self): + """ + Test the _take_sub_ntk method. + """ + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1], + ) + + # Test shape flattened NTK + jax_ntk._ntk_shape = (8, 8, 2, 2) + jax_ntk._is_flattened = True + jax_ntk.flatten = True + ntk = random.normal(random.PRNGKey(0), (16, 16)) + reduced_data = jax_ntk._reduce_data_to_labels(self.dataset_int) + label_indices = jax_ntk._get_label_indices(reduced_data) + combination = (0, 1) + sub_ntk = jax_ntk._take_sub_ntk(ntk, label_indices, combination) + assert sub_ntk.shape == (12, 12) + + # Test shape unflattened NTK + jax_ntk._ntk_shape = (8, 8, 2, 2) + jax_ntk._is_flattened = False + jax_ntk.flatten = False + ntk = random.normal(random.PRNGKey(0), (8, 8, 2, 2)) + reduced_data = jax_ntk._reduce_data_to_labels(self.dataset_int) + label_indices = jax_ntk._get_label_indices(reduced_data) + combination = (0, 1) + sub_ntk = jax_ntk._take_sub_ntk(ntk, label_indices, combination) + assert sub_ntk.shape == (6, 6, 2, 2) + + # Test entries of the sub-NTK + jax_ntk._ntk_shape = (4, 4, 2, 2) + jax_ntk._is_flattened = True + jax_ntk.flatten = True + # Create some easier to check the sub-NTK + targets = np.array([0, 0, 1, 1, 2, 2, 2, 2]) + dataset = { + "inputs": self.dataset_int["inputs"], + "targets": np.expand_dims(targets, axis=1), + } + # Reduce data to given labels and get label indices + reduced_data = jax_ntk._reduce_data_to_labels(dataset) + label_indices = jax_ntk._get_label_indices(reduced_data) + # NTK of selected labels + ntk = np.arange(8 * 8).reshape((4, 4, 2, 2)) + combination = (0,) + # This is what should be extracted + _sub_ntk = ntk[np.ix_(label_indices[0], label_indices[0])] + _sub_ntk, _ = flatten_rank_4_tensor(_sub_ntk) + # Compute the sub-NTK + ntk, _ = flatten_rank_4_tensor(ntk) + sub_ntk = jax_ntk._take_sub_ntk(ntk, label_indices, combination) + # Check if the sub-NTK is correct + assert np.all(sub_ntk == _sub_ntk) + + def test_compute_ntk(self): + """ + Test the compute_ntk method. 
+ """ + + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1], + ) + + params = {"params": self.flax_model.model_state.params} + + ntks = jax_ntk.compute_ntk(params, self.dataset_int) + assert len(ntks) == 3 + assert ntks[0].shape == (8, 8) + assert ntks[1].shape == (4, 4) + assert ntks[2].shape == (12, 12) + + jax_ntk = JAXNTKCombinations( + apply_fn=self.flax_model.apply, + class_labels=[0, 1, 2], + ) + + ntks = jax_ntk.compute_ntk(params, self.dataset_int) + assert len(ntks) == 7 + assert [np.shape(ntk) for ntk in ntks] == [ + (8, 8), + (4, 4), + (4, 4), + (12, 12), + (12, 12), + (8, 8), + (16, 16), + ] diff --git a/CI/unit_tests/analysis/test_jax_ntk_subsampling.py b/CI/unit_tests/analysis/test_jax_ntk_subsampling.py new file mode 100644 index 0000000..f425710 --- /dev/null +++ b/CI/unit_tests/analysis/test_jax_ntk_subsampling.py @@ -0,0 +1,133 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import jax.numpy as np +import optax +from flax import linen as nn +from jax import random + +from znnl.analysis import JAXNTKSubsampling +from znnl.models import FlaxModel + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=2, use_bias=True)(x) + return x + + +class TestJAXNTKSubsampling: + """ + Test class for the JAX NTK computation class. + """ + + @classmethod + def setup_class(cls): + """ + Setup the test class. + """ + cls.flax_model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(8,), + seed=17, + ) + + cls.dataset = { + "inputs": random.normal(random.PRNGKey(0), (10, 8)), + "targets": random.normal(random.PRNGKey(1), (10, 2)), + } + + def test_constructor(self): + """ + Test the constructor of the JAX NTK computation class. + """ + jax_ntk = JAXNTKSubsampling( + apply_fn=self.flax_model.ntk_apply_fn, ntk_size=3, seed=0 + ) + + assert jax_ntk.ntk_size == 3 + + def test_get_sample_indices(self): + """ + Test the _get_sample_indices method. + """ + jax_ntk = JAXNTKSubsampling( + apply_fn=self.flax_model.ntk_apply_fn, + ntk_size=3, + seed=0, + ) + + sample_indices = jax_ntk._get_sample_indices(self.dataset["inputs"]) + + assert len(sample_indices) == 3 + assert sample_indices[0].shape == (3,) + assert sample_indices[1].shape == (3,) + assert sample_indices[2].shape == (3,) + + def test_subsample_data(self): + """ + Test the _subsample_data method. 
+ """ + jax_ntk = JAXNTKSubsampling( + apply_fn=self.flax_model.ntk_apply_fn, + ntk_size=3, + seed=0, + ) + + jax_ntk._sample_indices = jax_ntk._get_sample_indices(self.dataset["inputs"]) + subsampled_data = jax_ntk._subsample_data(self.dataset["inputs"]) + + assert len(subsampled_data) == 3 + assert subsampled_data[0].shape == (3, 8) + assert subsampled_data[1].shape == (3, 8) + assert subsampled_data[2].shape == (3, 8) + + def test_compute_ntk(self): + """ + Test the compute_ntk method. + """ + + # Use vmap is False + jax_ntk = JAXNTKSubsampling( + apply_fn=self.flax_model.ntk_apply_fn, + ntk_size=3, + seed=0, + ) + + params = {"params": self.flax_model.model_state.params} + + ntk = jax_ntk.compute_ntk(params, self.dataset) + + assert np.shape(ntk) == (3, 6, 6) diff --git a/CI/unit_tests/models/_test_huggingface_flax_model.py b/CI/unit_tests/models/_test_huggingface_flax_model.py index 7d77063..fc7411f 100644 --- a/CI/unit_tests/models/_test_huggingface_flax_model.py +++ b/CI/unit_tests/models/_test_huggingface_flax_model.py @@ -26,10 +26,10 @@ """ import optax -import pytest from jax import random from transformers import FlaxResNetForImageClassification, ResNetConfig +from znnl.analysis import JAXNTKComputation from znnl.models import HuggingFaceFlaxModel @@ -68,30 +68,27 @@ def setup_class(cls): cls.model = HuggingFaceFlaxModel( hf_model, optax.adam(learning_rate=0.001), - batch_size=3, ) key = random.PRNGKey(0) cls.x = random.normal(key, (3, 2, 8, 8)) + cls.ntk_computation = JAXNTKComputation( + cls.model.ntk_apply_fn, trace_axes=(-1,) + ) + def test_ntk_shape(self): """ Test whether the NTK shape is correct. """ - ntk = self.model.compute_ntk(self.x)["empirical"] + ntk = self.ntk_computation.compute_ntk( + self.model.model_state.params, {"inputs": self.x, "targets": None} + ) assert ntk.shape == (3, 3) - def test_infinite_failure(self): - """ - Test that the call to the infinite NTK fails. - """ - with pytest.raises(NotImplementedError): - self.model.compute_ntk(self.x, infinite=True) - if __name__ == "__main__": test_class = TestFlaxHFModule() test_class.setup_class() - # test_class.test_infinite_failure() test_class.test_ntk_shape() diff --git a/CI/unit_tests/models/test_flax_model.py b/CI/unit_tests/models/test_flax_model.py index d6bf1c3..05db29b 100644 --- a/CI/unit_tests/models/test_flax_model.py +++ b/CI/unit_tests/models/test_flax_model.py @@ -29,11 +29,12 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import jax.numpy as np import optax -import pytest from flax import linen as nn from jax import random +from znnl.analysis import JAXNTKComputation from znnl.models import FlaxModel @@ -67,24 +68,11 @@ def test_ntk_shape(self): seed=17, ) - key1, key2 = random.split(random.PRNGKey(1), 2) - x = random.normal(key1, (3, 8)) - ntk = model.compute_ntk(x)["empirical"] - assert ntk.shape == (3, 3) - - def test_infinite_failure(self): - """ - Test that the call to the infinite NTK fails. 
- """ - model = FlaxModel( - flax_module=FlaxTestModule(), - optimizer=optax.adam(learning_rate=0.001), - input_shape=(8,), - seed=17, - ) + ntk_computation = JAXNTKComputation(model.ntk_apply_fn, trace_axes=(-1,)) key1, key2 = random.split(random.PRNGKey(1), 2) x = random.normal(key1, (3, 8)) - - with pytest.raises(NotImplementedError): - model.compute_ntk(x, infinite=True) + ntk = ntk_computation.compute_ntk( + {"params": model.model_state.params}, {"inputs": x, "targets": None} + ) + assert np.shape(ntk) == (1, 3, 3) diff --git a/CI/unit_tests/models/test_nt_model.py b/CI/unit_tests/models/test_nt_model.py index f719c95..b1114da 100644 --- a/CI/unit_tests/models/test_nt_model.py +++ b/CI/unit_tests/models/test_nt_model.py @@ -29,10 +29,12 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import jax.numpy as np import optax from jax import random from neural_tangents import stax +from znnl.analysis import JAXNTKComputation from znnl.models import NTModel @@ -61,7 +63,6 @@ def test_ntk_shape(self): nt_module=test_model, optimizer=optax.adam(learning_rate=0.001), input_shape=(1, 8), - batch_size=1, seed=17, ) @@ -69,13 +70,17 @@ def test_ntk_shape(self): nt_module=test_model, optimizer=optax.adam(learning_rate=0.001), input_shape=(1, 8), - batch_size=1, seed=17, - trace_axes=(), ) - ntk_1 = nt_model_1.compute_ntk(x1)["empirical"] - ntk_2 = nt_model_2.compute_ntk(x1)["empirical"] + ntk_computation_1 = JAXNTKComputation(nt_model_1.ntk_apply_fn, trace_axes=(-1,)) + ntk_1 = ntk_computation_1.compute_ntk( + {"params": nt_model_1.model_state.params}, {"inputs": x1, "targets": None} + ) + ntk_computation_2 = JAXNTKComputation(nt_model_2.ntk_apply_fn, trace_axes=()) + ntk_2 = ntk_computation_2.compute_ntk( + {"params": nt_model_2.model_state.params}, {"inputs": x1, "targets": None} + ) - assert ntk_1.shape == (3, 3) - assert ntk_2.shape == (3, 3, 5, 5) + assert np.shape(ntk_1) == (1, 3, 3) + assert np.shape(ntk_2) == (1, 15, 15) diff --git a/CI/unit_tests/models/test_seed.py b/CI/unit_tests/models/test_seed.py index 7cc2053..17a8207 100644 --- a/CI/unit_tests/models/test_seed.py +++ b/CI/unit_tests/models/test_seed.py @@ -55,7 +55,6 @@ def test_models(self): nt_module=test_model, optimizer=optax.adam(learning_rate=0.001), input_shape=(1,), - batch_size=1, seed=17, ) @@ -63,7 +62,6 @@ def test_models(self): nt_module=test_model, optimizer=optax.adam(learning_rate=0.001), input_shape=(1,), - batch_size=1, seed=17, ) return nt_model_1, nt_model_2 diff --git a/CI/unit_tests/optimizers/test_trace_optimizer.py b/CI/unit_tests/optimizers/test_trace_optimizer.py index 9cdb12e..55b5e08 100644 --- a/CI/unit_tests/optimizers/test_trace_optimizer.py +++ b/CI/unit_tests/optimizers/test_trace_optimizer.py @@ -32,6 +32,7 @@ import jax.numpy as np from neural_tangents import stax +from znnl.analysis import JAXNTKComputation from znnl.data import MNISTGenerator from znnl.models import NTModel from znnl.optimizers import TraceOptimizer @@ -81,18 +82,23 @@ def test_apply_operation(self): optimizer=optimizer, input_shape=(1, 28, 28, 1), nt_module=network, - batch_size=5, + ) + ntk_computation = JAXNTKComputation( + model.ntk_apply_fn, trace_axes=(-1,), batch_size=5 ) # Get theoretical values - ntk = model.compute_ntk(data.train_ds["inputs"])["empirical"] + ntk = ntk_computation.compute_ntk( + {"params": model.model_state.params}, data.train_ds + ) + ntk = np.array(ntk).mean(axis=0) expected_lr = scale_factor / np.trace(ntk) # Compute actual values actual_lr = optimizer.apply_optimizer( model_state=model.model_state, - 
data_set=data.train_ds["inputs"], - ntk_fn=model.compute_ntk, + data_set=data.train_ds, + ntk_fn=ntk_computation.compute_ntk, epoch=1, ).opt_state diff --git a/CI/unit_tests/training_recording/test_data_storage.py b/CI/unit_tests/training_recording/test_data_storage.py index 6cff349..9a02c67 100644 --- a/CI/unit_tests/training_recording/test_data_storage.py +++ b/CI/unit_tests/training_recording/test_data_storage.py @@ -1,111 +1,111 @@ -""" -ZnNL: A Zincwarecode package. - -License -------- -This program and the accompanying materials are made available under the terms -of the Eclipse Public License v2.0 which accompanies this distribution, and is -available at https://www.eclipse.org/legal/epl-v20.html - -SPDX-License-Identifier: EPL-2.0 - -Copyright Contributors to the Zincwarecode Project. - -Contact Information -------------------- -email: zincwarecode@gmail.com -github: https://github.com/zincware -web: https://zincwarecode.com/ - -Citation --------- -If you use this module please cite us with: - -Summary -------- -""" - -import tempfile -from dataclasses import dataclass -from os import path -from pathlib import Path - -import h5py as hf -import numpy as onp -from numpy import testing - -from znnl.training_recording import DataStorage - - -@dataclass -class DataClass: - """ - Dummy data class for testing - """ - - vector_data: onp.ndarray - tensor_data: onp.ndarray - - -class TestDataStorage: - """ - Test suite for the storage module. - """ - - @classmethod - def setup_class(cls): - """ - Set up the test. - """ - cls.vector_data = onp.random.uniform(size=(100,)) - cls.tensor_data = onp.random.uniform(size=(100, 10, 10)) - - cls.data_object = DataClass( - vector_data=cls.vector_data, tensor_data=cls.tensor_data - ) - - def test_database_construction(self): - """ - Test that database groups are built properly. - """ - # Create temporary directory for safe testing. - with tempfile.TemporaryDirectory() as directory: - database_path = path.join(directory, "test_creation") - data_storage = DataStorage(Path(database_path)) - data_storage.write_data(self.data_object) # write some data to empty DB. - - with hf.File(data_storage.database_path, "r") as db: - # Test correct dataset creation. - keys = list(db.keys()) - testing.assert_equal(keys, ["tensor_data", "vector_data"]) - vector_data = onp.array(db["vector_data"]) - tensor_data = onp.array(db["tensor_data"]) - - # Check data structure within the db. - assert vector_data.shape == (100,) - assert vector_data.sum() != 0.0 - - assert tensor_data.shape == (100, 10, 10) - assert tensor_data.sum() != 0.0 - - def test_resize_dataset_standard(self): - """ - Test if the datasets are resized properly. - """ - with tempfile.TemporaryDirectory() as directory: - database_path = path.join(directory, "test_resize") - data_storage = DataStorage(Path(database_path)) - data_storage.write_data(self.data_object) # write some data to empty DB. - data_storage.write_data(self.data_object) # force resize. - - with hf.File(data_storage.database_path, "r") as db: - # Test correct dataset creation. - vector_data = onp.array(db["vector_data"]) - tensor_data = onp.array(db["tensor_data"]) - - # Check data structure within the db. - assert vector_data.shape == (200,) - assert vector_data[100:].sum() != 0.0 - - assert tensor_data.shape == (200, 10, 10) - assert tensor_data[100:].sum() != 0.0 +# """ +# ZnNL: A Zincwarecode package. 
+ +# License +# ------- +# This program and the accompanying materials are made available under the terms +# of the Eclipse Public License v2.0 which accompanies this distribution, and is +# available at https://www.eclipse.org/legal/epl-v20.html + +# SPDX-License-Identifier: EPL-2.0 + +# Copyright Contributors to the Zincwarecode Project. + +# Contact Information +# ------------------- +# email: zincwarecode@gmail.com +# github: https://github.com/zincware +# web: https://zincwarecode.com/ + +# Citation +# -------- +# If you use this module please cite us with: + +# Summary +# ------- +# """ + +# import tempfile +# from dataclasses import dataclass +# from os import path +# from pathlib import Path + +# import h5py as hf +# import numpy as onp +# from numpy import testing + +# from znnl.training_recording import DataStorage + + +# @dataclass +# class DataClass: +# """ +# Dummy data class for testing +# """ + +# vector_data: onp.ndarray +# tensor_data: onp.ndarray + + +# class TestDataStorage: +# """ +# Test suite for the storage module. +# """ + +# @classmethod +# def setup_class(cls): +# """ +# Set up the test. +# """ +# cls.vector_data = onp.random.uniform(size=(100,)) +# cls.tensor_data = onp.random.uniform(size=(100, 10, 10)) + +# cls.data_object = DataClass( +# vector_data=cls.vector_data, tensor_data=cls.tensor_data +# ) + +# def test_database_construction(self): +# """ +# Test that database groups are built properly. +# """ +# # Create temporary directory for safe testing. +# with tempfile.TemporaryDirectory() as directory: +# database_path = path.join(directory, "test_creation") +# data_storage = DataStorage(Path(database_path)) +# data_storage.write_data(self.data_object) # write some data to empty DB. + +# with hf.File(data_storage.database_path, "r") as db: +# # Test correct dataset creation. +# keys = list(db.keys()) +# testing.assert_equal(keys, ["tensor_data", "vector_data"]) +# vector_data = onp.array(db["vector_data"]) +# tensor_data = onp.array(db["tensor_data"]) + +# # Check data structure within the db. +# assert vector_data.shape == (100,) +# assert vector_data.sum() != 0.0 + +# assert tensor_data.shape == (100, 10, 10) +# assert tensor_data.sum() != 0.0 + +# def test_resize_dataset_standard(self): +# """ +# Test if the datasets are resized properly. +# """ +# with tempfile.TemporaryDirectory() as directory: +# database_path = path.join(directory, "test_resize") +# data_storage = DataStorage(Path(database_path)) +# data_storage.write_data(self.data_object) # write some data to empty DB. +# data_storage.write_data(self.data_object) # force resize. + +# with hf.File(data_storage.database_path, "r") as db: +# # Test correct dataset creation. +# vector_data = onp.array(db["vector_data"]) +# tensor_data = onp.array(db["tensor_data"]) + +# # Check data structure within the db. +# assert vector_data.shape == (200,) +# assert vector_data[100:].sum() != 0.0 + +# assert tensor_data.shape == (200, 10, 10) +# assert tensor_data[100:].sum() != 0.0 diff --git a/CI/unit_tests/training_recording/test_jax_recorder.py b/CI/unit_tests/training_recording/test_jax_recorder.py new file mode 100644 index 0000000..fc061d7 --- /dev/null +++ b/CI/unit_tests/training_recording/test_jax_recorder.py @@ -0,0 +1,195 @@ +""" +ZnNL: A Zincwarecode package. 
+ +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import tempfile + +import numpy as onp +import optax +from flax import linen as nn +from numpy.testing import assert_raises +from papyrus.measurements import Accuracy, Loss, NTKTrace + +from znnl.analysis import JAXNTKComputation +from znnl.models import FlaxModel +from znnl.training_recording import JaxRecorder + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=10, use_bias=True)(x) + return x + + +class TestJaxRecorder: + """ + Unit test suite for the JaxRecorder. + + Tests for parent recorder class are implemented in the papyrus package. + """ + + @classmethod + def setup_class(cls): + """ + Prepare the test suite. + """ + dummy_input = onp.random.uniform(size=(5, 3)) + dummy_target = onp.random.uniform(size=(5, 10)) + cls.dummy_data_set = {"inputs": dummy_input, "targets": dummy_target} + + cls.measurements = [ + Loss(apply_fn=lambda x, y: onp.sum((x - y) ** 2)), + Accuracy(), + NTKTrace(), + ] + cls.neural_state = { + "accuracy": [onp.random.uniform(size=(1,))], + "predictions": [onp.random.uniform(size=(5, 10))], + "targets": [onp.random.uniform(size=(5, 10))], + "ntk": [onp.random.uniform(size=(5, 5))], + } + cls.model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(2, 3), + seed=17, + ) + cls.ntk_computation = JAXNTKComputation( + apply_fn=cls.model.ntk_apply_fn, + trace_axes=(-1,), + ) + + def test_check_keys(self): + """ + Test the check_keys method. + """ + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + + recorder.neural_state = self.neural_state.copy() + + # Test correct keys + recorder._check_keys() + + # Test missing keys + del recorder.neural_state["accuracy"] + assert_raises(KeyError, recorder._check_keys) + + # Test additional keys (should not raise an error) + recorder.neural_state["accuracy"] = [onp.random.uniform(size=(1,))] + recorder.neural_state["additional_key"] = [onp.random.uniform(size=(5,))] + recorder._check_keys() + + def test_instantiate_recorder(self): + """ + Test the instantiate_recorder method. 
+ """ + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + recorder.instantiate_recorder( + data_set=self.dummy_data_set, + model=self.model, + ntk_computation=self.ntk_computation, + ) + assert recorder.neural_state == {} + assert recorder._data_set == self.dummy_data_set + assert recorder._model == self.model + assert recorder._ntk_computation == self.ntk_computation + + # Test errors for missing data + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + assert_raises( + AttributeError, + recorder.instantiate_recorder, + model=self.model, + ntk_computation=self.ntk_computation, + ) + + # Test errors for missing model + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + assert_raises( + AttributeError, + recorder.instantiate_recorder, + data_set=self.dummy_data_set, + ntk_computation=self.ntk_computation, + ) + + # Test errors for missing ntk_computation + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + assert_raises( + AttributeError, + recorder.instantiate_recorder, + data_set=self.dummy_data_set, + model=self.model, + ) + + def test_record(self): + """ + Test the record method. + """ + recorder = JaxRecorder( + name="test", + storage_path=tempfile.mkdtemp(), + measurements=self.measurements, + ) + recorder.instantiate_recorder( + data_set=self.dummy_data_set, + model=self.model, + ntk_computation=self.ntk_computation, + ) + + recorder.record(model=self.model, accuracy=[onp.array([0.5])], epoch=0) + + # Check if the neural state is updated + assert recorder.neural_state["accuracy"] == [onp.array([0.5])] + assert onp.shape(recorder.neural_state["predictions"]) == (1, 5, 10) + assert onp.shape(recorder.neural_state["targets"]) == (1, 5, 10) + assert onp.shape(recorder.neural_state["ntk"]) == (1, 5, 5) diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 7ab3952..7547679 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -1,168 +1,174 @@ -""" -ZnNL: A Zincwarecode package. - -License -------- -This program and the accompanying materials are made available under the terms -of the Eclipse Public License v2.0 which accompanies this distribution, and is -available at https://www.eclipse.org/legal/epl-v20.html - -SPDX-License-Identifier: EPL-2.0 - -Copyright Contributors to the Zincwarecode Project. - -Contact Information -------------------- -email: zincwarecode@gmail.com -github: https://github.com/zincware -web: https://zincwarecode.com/ - -Citation --------- -If you use this module please cite us with: - -Summary -------- -""" - -import tempfile -from pathlib import Path - -import h5py as hf -import numpy as onp -from numpy import testing - -from znnl.training_recording import JaxRecorder - - -class TestModelRecording: - """ - Unit test suite for the model recording. - """ - - @classmethod - def setup_class(cls): - """ - Prepare the test suite. - """ - dummy_data = onp.random.uniform(size=(5, 2, 3)) - cls.dummy_data_set = {"inputs": dummy_data, "targets": dummy_data} - - def test_instantiation(self): - """ - Test the instantiation of the recorder. 
- """ - - recorder = JaxRecorder( - loss=True, - accuracy=True, - ntk=True, - covariance_ntk=True, - magnitude_ntk=True, - entropy=True, - magnitude_entropy=True, - magnitude_variance=True, - covariance_entropy=True, - eigenvalues=True, - trace=True, - loss_derivative=True, - network_predictions=True, - ) - recorder.instantiate_recorder(data_set=self.dummy_data_set) - _exclude_list = [ - "_accuracy_fn", - "_loss_fn", - "update_rate", - "name", - "storage_path", - "chunk_size", - "flatten_ntk", - ] - for key, val in vars(recorder).items(): - if key[0] != "_" and key not in _exclude_list: - assert val is True - if key == "update_rate": - assert val == 1 - elif key.split("_")[-1] == "array:": - assert val == [] - elif key == "_selected_properties": - pass - - def test_data_dump(self): - """ - Test that data is dumped correctly. - """ - with tempfile.TemporaryDirectory() as directory: - recorder = JaxRecorder( - storage_path=directory, - name="my_recorder", - loss=True, - accuracy=False, - ntk=False, - covariance_ntk=True, - magnitude_ntk=True, - entropy=False, - magnitude_entropy=False, - magnitude_variance=False, - covariance_entropy=False, - eigenvalues=False, - ) - recorder.instantiate_recorder(data_set=self.dummy_data_set) - - # Add some dummy data. - test_data = onp.random.uniform(size=(200,)) - recorder._loss_array = test_data.tolist() - - recorder.dump_records() # dump to disk - - # Check that the dump worked. - assert Path(f"{directory}/my_recorder.h5").exists() - with hf.File(f"{directory}/my_recorder.h5", "r") as db: - testing.assert_almost_equal(db["loss"], test_data, decimal=7) - - def test_overwriting(self): - """ - Test the overwrite function. - """ - recorder = JaxRecorder( - loss=False, accuracy=False, ntk=True, entropy=False, eigenvalues=False - ) - recorder.instantiate_recorder(data_set=self.dummy_data_set) - - # Populate the arrays deliberately. - recorder._ntk_array = onp.random.uniform(size=(10, 5, 5)).tolist() - assert onp.sum(recorder._ntk_array) != 0.0 # check the data is there - - # Check normal resizing on instantiation. - recorder.instantiate_recorder(data_set=self.dummy_data_set, overwrite=False) - assert onp.shape(recorder._ntk_array) == (10, 5, 5) - - # Test overwriting. - recorder.instantiate_recorder(data_set=self.dummy_data_set, overwrite=True) - assert recorder._ntk_array == [] - - def test_magnitude_variance(self): - """ - Test the magnitude variance function. - """ - recorder = JaxRecorder( - loss=False, - accuracy=False, - ntk=False, - entropy=False, - magnitude_variance=True, - eigenvalues=False, - ) - recorder.instantiate_recorder(data_set=self.dummy_data_set) - - # Create some test data. - data = onp.random.uniform(1.0, 2.0, size=(100)) - ntk = onp.eye(100) * data - # calculate the magnitude variance - recorder._update_magnitude_variance(parsed_data={"ntk": ntk}) - # calculate the expected variance - expected_variance = onp.var(onp.sqrt(data) / onp.sqrt(data).mean()) - # check that the variance is correct - testing.assert_almost_equal( - recorder._magnitude_variance_array, expected_variance - ) +# """ +# ZnNL: A Zincwarecode package. + +# License +# ------- +# This program and the accompanying materials are made available under the terms +# of the Eclipse Public License v2.0 which accompanies this distribution, and is +# available at https://www.eclipse.org/legal/epl-v20.html + +# SPDX-License-Identifier: EPL-2.0 + +# Copyright Contributors to the Zincwarecode Project. 
+ +# Contact Information +# ------------------- +# email: zincwarecode@gmail.com +# github: https://github.com/zincware +# web: https://zincwarecode.com/ + +# Citation +# -------- +# If you use this module please cite us with: + +# Summary +# ------- +# """ + +# import tempfile +# from pathlib import Path + +# import h5py as hf +# import numpy as onp +# from numpy import testing + +# from znnl.training_recording import JaxRecorder + + +# class TestModelRecording: +# """ +# Unit test suite for the model recording. +# """ + +# @classmethod +# def setup_class(cls): +# """ +# Prepare the test suite. +# """ +# dummy_data = onp.random.uniform(size=(5, 2, 3)) +# cls.dummy_data_set = {"inputs": dummy_data, "targets": dummy_data} + +# def test_instantiation(self): +# """ +# Test the instantiation of the recorder. +# """ + +# recorder = JaxRecorder( +# loss=True, +# accuracy=True, +# ntk=True, +# covariance_ntk=True, +# magnitude_ntk=True, +# entropy=True, +# magnitude_entropy=True, +# magnitude_variance=True, +# covariance_entropy=True, +# eigenvalues=True, +# trace=True, +# loss_derivative=True, +# network_predictions=True, +# ) +# recorder.instantiate_recorder(data_set=self.dummy_data_set) +# _exclude_list = [ +# "_accuracy_fn", +# "_loss_fn", +# "update_rate", +# "name", +# "storage_path", +# "chunk_size", +# "flatten_ntk", +# ] +# for key, val in vars(recorder).items(): +# if key[0] != "_" and key not in _exclude_list: +# assert val is True +# if key == "update_rate": +# assert val == 1 +# elif key.split("_")[-1] == "array:": +# assert val == [] +# elif key == "_selected_properties": +# pass + +# def test_data_dump(self): +# """ +# Test that data is dumped correctly. +# """ +# with tempfile.TemporaryDirectory() as directory: +# recorder = JaxRecorder( +# storage_path=directory, +# name="my_recorder", +# loss=True, +# accuracy=False, +# ntk=False, +# covariance_ntk=True, +# magnitude_ntk=True, +# entropy=False, +# magnitude_entropy=False, +# magnitude_variance=False, +# covariance_entropy=False, +# eigenvalues=False, +# ) +# recorder.instantiate_recorder( +# data_set=self.dummy_data_set, ntk_computation=[] +# ) + +# # Add some dummy data. +# test_data = onp.random.uniform(size=(200,)) +# recorder._loss_array = test_data.tolist() + +# recorder.dump_records() # dump to disk + +# # Check that the dump worked. +# assert Path(f"{directory}/my_recorder.h5").exists() +# with hf.File(f"{directory}/my_recorder.h5", "r") as db: +# testing.assert_almost_equal(db["loss"], test_data, decimal=7) + +# def test_overwriting(self): +# """ +# Test the overwrite function. +# """ +# recorder = JaxRecorder( +# loss=False, accuracy=False, ntk=True, entropy=False, eigenvalues=False +# ) +# recorder.instantiate_recorder(data_set=self.dummy_data_set, ntk_computation=[]) + +# # Populate the arrays deliberately. +# recorder._ntk_array = onp.random.uniform(size=(10, 5, 5)).tolist() +# assert onp.sum(recorder._ntk_array) != 0.0 # check the data is there + +# # Check normal resizing on instantiation. +# recorder.instantiate_recorder( +# data_set=self.dummy_data_set, overwrite=False, ntk_computation=[] +# ) +# assert onp.shape(recorder._ntk_array) == (10, 5, 5) + +# # Test overwriting. +# recorder.instantiate_recorder( +# data_set=self.dummy_data_set, overwrite=True, ntk_computation=[] +# ) +# assert recorder._ntk_array == [] + +# def test_magnitude_variance(self): +# """ +# Test the magnitude variance function. 
+# """ +# recorder = JaxRecorder( +# loss=False, +# accuracy=False, +# ntk=False, +# entropy=False, +# magnitude_variance=True, +# eigenvalues=False, +# ) +# recorder.instantiate_recorder(data_set=self.dummy_data_set, ntk_computation=[]) + +# # Create some test data. +# data = onp.random.uniform(1.0, 2.0, size=(100)) +# ntk = onp.eye(100) * data +# # calculate the magnitude variance +# recorder._update_magnitude_variance(parsed_data={"ntk": ntk}) +# # calculate the expected variance +# expected_variance = onp.var(onp.sqrt(data) / onp.sqrt(data).mean()) +# # check that the variance is correct +# testing.assert_almost_equal( +# recorder._magnitude_variance_array, expected_variance +# ) diff --git a/CI/unit_tests/utils/test_matrix_utils.py b/CI/unit_tests/utils/test_matrix_utils.py index bc73ded..2ce2b22 100644 --- a/CI/unit_tests/utils/test_matrix_utils.py +++ b/CI/unit_tests/utils/test_matrix_utils.py @@ -54,7 +54,7 @@ def test_unscaled_eigenvalues(self): values, vectors = compute_eigensystem(matrix, normalize=False) - assert_array_equal(np.real(values), [1.0, 1.0]) + assert_array_almost_equal(np.real(values), [1.0, 1.0]) def test_scaled_eigenvalues(self): """ diff --git a/examples/CIFAR10.ipynb b/examples/CIFAR10.ipynb index 1103245..94c3461 100644 --- a/examples/CIFAR10.ipynb +++ b/examples/CIFAR10.ipynb @@ -254,7 +254,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/Computing-Collective-Variables.ipynb b/examples/Computing-Collective-Variables.ipynb index c5500bb..23928ea 100644 --- a/examples/Computing-Collective-Variables.ipynb +++ b/examples/Computing-Collective-Variables.ipynb @@ -29,13 +29,17 @@ }, "outputs": [], "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", + "# import os\n", + "# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", "\n", "import znnl as nl\n", "from neural_tangents import stax\n", "import optax\n", "\n", + "from papyrus.measurements import (\n", + " Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues, LossDerivative,\n", + ")\n", + "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", @@ -151,10 +155,24 @@ " nt_module=dense_network,\n", " optimizer=optax.adam(learning_rate=0.005),\n", " input_shape=(9,),\n", - " batch_size=314\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec88fd2a", + "metadata": {}, + "outputs": [], + "source": [ + "ntk_computation = nl.analysis.JAXNTKComputation(\n", + " apply_fn=fuel_model.ntk_apply_fn, \n", + " batch_size=314,\n", + ")\n", + "\n", + "loss_derivative_computation = nl.analysis.LossDerivative(loss_fn=nl.loss_functions.LPNormLoss(order=2))" + ] + }, { "cell_type": "markdown", "id": "a6a9fbe0-def2-4bab-a808-8858ab2aa5e9", @@ -172,7 +190,7 @@ "- NTK\n", "- NTK Eigenvalues\n", "- Entropy of the NTK\n", - "- Magnitude Variance of the NTK\n", + "- Self-Entropy of the NTK\n", "- Trace of the NTK\n", "- Frobenius norm of the Loss Derivative\n", "\n", @@ -188,27 +206,40 @@ "source": [ "train_recorder = nl.training_recording.JaxRecorder(\n", " name=\"train_recorder\",\n", - " loss=True,\n", - " ntk=True,\n", - " covariance_entropy=True,\n", - " eigenvalues=True,\n", - " magnitude_variance=True, \n", - " trace=True,\n", - " loss_derivative=True,\n", - " update_rate=1\n", + " measurements=[\n", + " Loss(name=\"loss\", apply_fn=nl.loss_functions.LPNormLoss(order=2)),\n", + " Accuracy(name=\"accuracy\", 
apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
+    "        NTKTrace(name=\"ntk_trace\"),\n",
+    "        NTKEntropy(name=\"ntk_entropy\"),\n",
+    "        NTK(name=\"ntk\"),\n",
+    "        NTKSelfEntropy(name=\"ntk_self_entropy\"),\n",
+    "        NTKEigenvalues(name=\"ntk_eigenvalues\"),\n",
+    "        LossDerivative(name=\"loss_derivative\", apply_fn=loss_derivative_computation.calculate),\n",
+    "    ],\n",
+    "    storage_path=\".\",\n",
+    "    update_rate=1, \n",
+    "    chunk_size=1e5\n",
     ")\n",
     "train_recorder.instantiate_recorder(\n",
-    "    data_set=data_generator.train_ds\n",
+    "    data_set=data_generator.train_ds, \n",
+    "    ntk_computation=ntk_computation, \n",
+    "    model=fuel_model\n",
     ")\n",
     "\n",
     "\n",
     "test_recorder = nl.training_recording.JaxRecorder(\n",
     "    name=\"test_recorder\",\n",
-    "    loss=True,\n",
-    "    update_rate=1\n",
+    "    measurements=[\n",
+    "        Loss(name=\"loss\", apply_fn=nl.loss_functions.LPNormLoss(order=2)),\n",
+    "        Accuracy(name=\"accuracy\", apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
+    "    ],\n",
+    "    storage_path=\".\",\n",
+    "    update_rate=1, \n",
+    "    chunk_size=1e5\n",
     ")\n",
     "test_recorder.instantiate_recorder(\n",
-    "    data_set=data_generator.test_ds\n",
+    "    data_set=data_generator.test_ds, \n",
+    "    model=fuel_model\n",
     ")"
    ]
   },
@@ -281,8 +312,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "train_report = train_recorder.gather_recording()\n",
-    "test_report = test_recorder.gather_recording()"
+    "train_report = train_recorder.gather()\n",
+    "test_report = test_recorder.gather()"
    ]
   },
   {
@@ -292,8 +323,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.plot(train_report.loss, 'o', mfc='None', label=\"Train\")\n",
-    "plt.plot(test_report.loss, 'o', mfc='None', label=\"Train\")\n",
+    "plt.plot(train_report[\"loss\"], 'o', mfc='None', label=\"Train\")\n",
+    "plt.plot(test_report[\"loss\"], 'o', mfc='None', label=\"Test\")\n",
     "\n",
     "plt.xlabel(\"Epoch\")\n",
     "plt.ylabel(\"Loss\")\n",
@@ -309,7 +340,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.plot(train_report.covariance_entropy, 'o', mfc='None', label=\"Entropy\")\n",
+    "plt.plot(train_report['ntk_entropy'], 'o', mfc='None', label=\"Entropy\")\n",
     "plt.xlabel(\"Epoch\")\n",
     "plt.ylabel(\"Entropy\")\n",
     "plt.legend()\n",
@@ -323,7 +354,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.plot(train_report.magnitude_variance, 'o', mfc='None', label=\"Magnitude Variance\")\n",
+    "plt.plot(train_report['ntk_self_entropy'], 'o', mfc='None', label=\"NTK Self-Entropy\")\n",
-    "plt.xlabel(\"Epoch\")\n",
-    "plt.ylabel(\"Magnitude Variance\")\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Self-Entropy\")\n",
     "plt.legend()\n",
@@ -333,25 +364,32 @@
    {
     "cell_type": "code",
     "execution_count": null,
-    "id": "0d3813e3",
+    "id": "6d43257c-defc-4f1e-816a-ebe1ae79e7ca",
     "metadata": {},
     "outputs": [],
     "source": [
-    "train_report.eigenvalues.shape"
+    "plt.plot(train_report['ntk_trace'], 'o', mfc='None', label=\"Trace\")\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Trace\")\n",
+    "plt.legend()\n",
+    "plt.show()"
     ]
    },
    {
     "cell_type": "code",
     "execution_count": null,
-    "id": "6d43257c-defc-4f1e-816a-ebe1ae79e7ca",
+    "id": "2ff7f726",
     "metadata": {},
     "outputs": [],
     "source": [
-    "plt.plot(train_report.trace, 'o', mfc='None', label=\"Trace\")\n",
+    "plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,0], 'o', mfc='None', label=\"Largest EV\")\n",
+    "plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,1], 'o', mfc='None', label=\"2nd Largest EV\")\n",
+    "plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,2], 'o', mfc='None', label=\"3rd Largest EV\")\n",
+    "plt.plot(np.array(train_report['ntk_eigenvalues'])[:,:,3], 'o', 
mfc='None', label=\"4th Largest EV\")\n", "plt.xlabel(\"Epoch\")\n", - "plt.ylabel(\"Trace\")\n", - "plt.legend()\n", - "plt.show()" + "plt.ylabel(\"Eigenvalues\")\n", + "plt.yscale(\"log\")\n", + "plt.legend()" ] }, { @@ -364,7 +402,7 @@ "calculate_l_pq_norm = nl.utils.matrix_utils.calculate_l_pq_norm\n", "\n", "l_pq_norms = np.array([\n", - " calculate_l_pq_norm(i) for i in train_report.loss_derivative\n", + " calculate_l_pq_norm(i) for i in train_report[\"loss_derivative\"]\n", "])\n", "\n", "plt.plot(\n", diff --git a/examples/Contrastive-Loss.ipynb b/examples/Contrastive-Loss.ipynb index 3035d70..a13117a 100644 --- a/examples/Contrastive-Loss.ipynb +++ b/examples/Contrastive-Loss.ipynb @@ -32,7 +32,11 @@ "import optax\n", "from neural_tangents import stax\n", "import matplotlib.pyplot as plt\n", - "import pandas as pd" + "import pandas as pd\n", + "\n", + "from papyrus.measurements import (\n", + " Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues\n", + ")" ] }, { @@ -97,9 +101,8 @@ " \n", "model = znnl.models.FlaxModel(\n", " flax_module=Architecture(),\n", - " optimizer=optax.adam(learning_rate=0.01),\n", + " optimizer=optax.adam(learning_rate=0.005),\n", " input_shape=input_shape,\n", - " batch_size=10,\n", " seed=0,\n", ")" ] @@ -118,6 +121,27 @@ "One for each potential and one to track the NTK and according observables. " ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "44229b8d", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = znnl.loss_functions.ContrastiveIsolatedPotentialLoss(\n", + " attractive_pot_fn=znnl.loss_functions.MeanPowerLoss(order=2), \n", + " repulsive_pot_fn=znnl.loss_functions.ExponentialRepulsionLoss(), \n", + " external_pot_fn=znnl.loss_functions.ExternalPotential(), \n", + " turn_off_attractive_potential=False,\n", + " turn_off_repulsive_potential=False,\n", + " turn_off_external_potential=False,\n", + " )\n", + "\n", + "def attractive_loss(point1, point2): return loss_fn.compute_losses(point1, point2)[0]\n", + "def repulsive_loss(point1, point2): return loss_fn.compute_losses(point1, point2)[1]\n", + "def external_loss(point1, point2): return loss_fn.compute_losses(point1, point2)[2]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -129,47 +153,61 @@ "source": [ "attractive_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"attractive_recorder\",\n", - " loss=True, \n", + " storage_path='.',\n", + " # loss=True, \n", + " measurements=[Loss(apply_fn=attractive_loss)],\n", " update_rate=1, \n", - " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "attractive_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model\n", ")\n", "\n", "repulsive_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"repulsive_recorder\",\n", - " loss=True, \n", + " storage_path='.',\n", + " measurements=[Loss(apply_fn=repulsive_loss)],\n", " update_rate=1, \n", - " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "repulsive_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model\n", ")\n", "\n", "external_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"external_recorder\",\n", - " loss=True, \n", + " storage_path='.',\n", + " measurements=[Loss(apply_fn=external_loss)],\n", " update_rate=1, \n", - " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", 
")\n", "external_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model\n", ")\n", "\n", "ntk_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"nrk_recorder\",\n", - " # trace=True, \n", - " # entropy=True,\n", - " # covariance_entropy=True, \n", + " storage_path='.',\n", + " measurements=[NTKTrace(), NTKEntropy()],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", + "ntk_computation = znnl.analysis.JAXNTKComputation(\n", + " apply_fn=model.ntk_apply_fn, \n", + " batch_size=10,\n", + ")\n", "ntk_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " ntk_computation=ntk_computation, \n", + " model=model\n", ")\n", "\n", - "recorders = [attractive_recorder, repulsive_recorder, external_recorder, ntk_recorder]" + "recorders = [\n", + " attractive_recorder, \n", + " repulsive_recorder, \n", + " external_recorder, \n", + " # ntk_recorder\n", + "]" ] }, { @@ -192,47 +230,16 @@ "source": [ "trainer = znnl.training_strategies.SimpleTraining(\n", " model=model,\n", - " loss_fn=znnl.loss_functions.ContrastiveIsolatedPotentialLoss(\n", - " attractive_pot_fn=znnl.loss_functions.MeanPowerLoss(order=2), \n", - " repulsive_pot_fn=znnl.loss_functions.ExponentialRepulsionLoss(), \n", - " external_pot_fn=znnl.loss_functions.ExternalPotential(), \n", - " turn_off_attractive_potential=False,\n", - " turn_off_repulsive_potential=False,\n", - " turn_off_external_potential=False,\n", - " ),\n", + " loss_fn=loss_fn,\n", " recorders=recorders, \n", " seed=0,\n", ")" ] }, - { - "cell_type": "markdown", - "id": "3621c13f-5216-4aa7-a512-a24d62a3a0b7", - "metadata": {}, - "source": [ - "### Get the correct loss functions into the recorders" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e435076f-2122-4b82-8300-09bc5ec77943", - "metadata": {}, - "outputs": [], - "source": [ - "def attractive_loss(point1, point2): return trainer.loss_fn.compute_losses(point1, point2)[0]\n", - "def repulsive_loss(point1, point2): return trainer.loss_fn.compute_losses(point1, point2)[1]\n", - "def external_loss(point1, point2): return trainer.loss_fn.compute_losses(point1, point2)[2]\n", - "\n", - "attractive_recorder._loss_fn = attractive_loss\n", - "repulsive_recorder._loss_fn = repulsive_loss\n", - "external_recorder._loss_fn = external_loss" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "6a69e21a-78f2-4a37-be09-fee188c4e6bc", + "id": "ea7f7cb9", "metadata": {}, "outputs": [], "source": [ @@ -284,11 +291,11 @@ "metadata": {}, "outputs": [], "source": [ - "attractive_results = attractive_recorder.gather_recording()\n", - "repulsive_results = repulsive_recorder.gather_recording()\n", - "external_results = external_recorder.gather_recording()\n", + "attractive_results = attractive_recorder.gather()\n", + "repulsive_results = repulsive_recorder.gather()\n", + "external_results = external_recorder.gather()\n", "\n", - "ntk_results = ntk_recorder.gather_recording()" + "ntk_results = ntk_recorder.gather()" ] }, { @@ -308,9 +315,9 @@ "source": [ "plt.figure(figsize=(8, 5))\n", "\n", - "plt.plot(attractive_results.loss, label=\"attractive\")\n", - "plt.plot(repulsive_results.loss, label=\"repulsive\")\n", - "plt.plot(external_results.loss, label=\"external\")\n", + "plt.plot(attractive_results['loss'], label=\"attractive\")\n", + "plt.plot(repulsive_results['loss'], label=\"repulsive\")\n", 
+ "plt.plot(external_results['loss'], label=\"external\")\n", "\n", "plt.legend()\n", "plt.yscale('log')\n", @@ -400,7 +407,6 @@ " flax_module=Architecture(),\n", " optimizer=optax.adam(learning_rate=0.001),\n", " input_shape=input_shape,\n", - " batch_size=20,\n", " seed=0,\n", ")" ] @@ -422,37 +428,54 @@ "source": [ "train_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"train_recorder\",\n", - " loss=True, \n", + " storage_path='.',\n", + " # loss=True, \n", + " measurements=[\n", + " Loss(apply_fn=znnl.loss_functions.ContrastiveInfoNCELoss(temperature=0.05))\n", + " ],\n", " update_rate=1, \n", - " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "train_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model\n", ")\n", "\n", "test_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"test_recorder\",\n", - " loss=True, \n", + " storage_path='.',\n", + " measurements=[\n", + " Loss(apply_fn=znnl.loss_functions.ContrastiveInfoNCELoss(temperature=0.05))\n", + " ],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "test_recorder.instantiate_recorder(\n", - " data_set=data_generator.test_ds\n", + " data_set=data_generator.test_ds, \n", + " model=model\n", ")\n", "\n", "ntk_recorder = znnl.training_recording.JaxRecorder(\n", " name=\"nrk_recorder\",\n", - " # trace=True, \n", - " # entropy=True,\n", - " # covariance_entropy=True,\n", + " storage_path='.',\n", + " measurements=[NTKTrace(), NTKEntropy()],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", + "ntk_computation = znnl.analysis.JAXNTKComputation(\n", + " apply_fn=model.ntk_apply_fn, \n", + " batch_size=10,\n", + ")\n", "ntk_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model,\n", + " ntk_computation=ntk_computation\n", ")\n", "\n", - "recorders = [train_recorder, test_recorder, ntk_recorder]" + "recorders = [\n", + " train_recorder, \n", + " test_recorder, \n", + " # ntk_recorder\n", + "]" ] }, { @@ -518,10 +541,10 @@ "metadata": {}, "outputs": [], "source": [ - "train_results = train_recorder.gather_recording()\n", - "test_results = test_recorder.gather_recording()\n", + "train_results = train_recorder.gather()\n", + "test_results = test_recorder.gather()\n", "\n", - "ntk_results = ntk_recorder.gather_recording()" + "ntk_results = ntk_recorder.gather()" ] }, { @@ -541,8 +564,8 @@ "source": [ "plt.figure(figsize=(8, 5))\n", "\n", - "plt.plot(train_results.loss, label=\"train\")\n", - "plt.plot(test_results.loss, label=\"test\")\n", + "plt.plot(train_results['loss'], label=\"train\")\n", + "plt.plot(test_results['loss'], label=\"test\")\n", "\n", "plt.legend()\n", "plt.yscale('log')\n", @@ -595,6 +618,14 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afb909f3", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/Neural-Mutual-Information.ipynb b/examples/Neural-Mutual-Information.ipynb new file mode 100644 index 0000000..3025b5d --- /dev/null +++ b/examples/Neural-Mutual-Information.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Computing the Neural Mutual Information (MI) \n", + "\n", + "In this notebook we will show how to compute the Neural Mutual 
Information (NMI) between classes of data during training. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# import os\n",
+    "# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
+    "\n",
+    "import znnl as nl\n",
+    "from flax import linen as nn\n",
+    "import optax\n",
+    "\n",
+    "from papyrus.measurements import (\n",
+    "    Loss, Accuracy, NTKEntropy\n",
+    ")\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "\n",
+    "import jax\n",
+    "jax.default_backend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For this demo of ZnNL, we will reduce the number of data points used for training and for computing the Mutual Information.\n",
+    "To scale the computation, just increase the selected number of data points."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "num_train = 100\n",
+    "num_nmi_per_class = 2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Data generators\n",
+    "\n",
+    "We will look at the NTK properties of the MNIST data set for a small model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_generator = nl.data.MNISTGenerator(num_train)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Networks and Models\n",
+    "\n",
+    "Now we can define the network architectures for which we will compute the NTK of the data.\n",
+    "\n",
+    "The batch size defined in the NTK computation class refers to the batching in the NTK calculation. When calculating the NTK, the number of data points used in that calculation must be an integer multiple of the batch size. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class DenseModule(nn.Module):\n",
+    "    \"\"\"\n",
+    "    Simple dense module.\n",
+    "    \"\"\"\n",
+    "    @nn.compact\n",
+    "    def __call__(self, x):\n",
+    "        x = x.reshape((x.shape[0], -1)) # flatten\n",
+    "        x = nn.Dense(features=32)(x)\n",
+    "        x = nn.relu(x)\n",
+    "        x = nn.Dense(features=32)(x)\n",
+    "        x = nn.relu(x)\n",
+    "        x = nn.Dense(features=10)(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = nl.models.FlaxModel(\n",
+    "    flax_module=DenseModule(),\n",
+    "    optimizer=optax.sgd(learning_rate=0.005, momentum=0.9),\n",
+    "    input_shape=(1, 28, 28, 1),\n",
+    ")"
+   ]
+  },
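+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick, optional sanity check we can count the trainable parameters. This is only a sketch, assuming the model exposes `model_state.params` as in the ResNet example; it can be skipped."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax.flatten_util\n",
+    "\n",
+    "# Flatten the parameter pytree into one vector and read off its length.\n",
+    "num_params = jax.flatten_util.ravel_pytree(model.model_state.params)[0].shape[0]\n",
+    "print(f\"Number of parameters: {num_params}\")"
+   ]
+  },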
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Recording \n",
+    "\n",
+    "We will record the loss and accuracy of the train and test data sets during training to see how well the model is learning."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_recorder = nl.training_recording.JaxRecorder(\n",
+    "    name=\"train_recorder\",\n",
+    "    measurements=[\n",
+    "        Loss(name=\"loss\", apply_fn=nl.loss_functions.CrossEntropyLoss()),\n",
+    "        Accuracy(name=\"accuracy\", apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
+    "    ],\n",
+    "    storage_path=\".\",\n",
+    "    update_rate=1, \n",
+    "    chunk_size=1e5\n",
+    ")\n",
+    "train_recorder.instantiate_recorder(\n",
+    "    data_set=data_generator.train_ds, \n",
+    "    model=model\n",
+    ")\n",
+    "\n",
+    "\n",
+    "test_recorder = nl.training_recording.JaxRecorder(\n",
+    "    name=\"test_recorder\",\n",
+    "    measurements=[\n",
+    "        Loss(name=\"loss\", apply_fn=nl.loss_functions.CrossEntropyLoss()),\n",
+    "        Accuracy(name=\"accuracy\", apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
+    "    ],\n",
+    "    storage_path=\".\",\n",
+    "    update_rate=1, \n",
+    "    chunk_size=1e5\n",
+    ")\n",
+    "test_recorder.instantiate_recorder(\n",
+    "    data_set=data_generator.test_ds, \n",
+    "    model=model\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Computing the Neural MI\n",
+    "\n",
+    "In order to compute the Neural MI, we will need to compute the von Neumann Entropy of the NTK. \n",
+    "We will create a subset of the training data to compute the NTK."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n_points = num_nmi_per_class\n",
+    "\n",
+    "mni_ds = {\"inputs\": [], \"targets\": []}\n",
+    "\n",
+    "for i in [0, 1, 8]:\n",
+    "    idx = np.argmax(data_generator.train_ds['targets'], axis=1) == i\n",
+    "    num_points = np.sum(idx)\n",
+    "    print(f\"Number of data points for class {i}: {num_points} of which {n_points} will be selected\")\n",
+    "\n",
+    "    # Select the first n_points of class i\n",
+    "    mni_ds[\"inputs\"].extend(data_generator.train_ds['inputs'][idx][:n_points])\n",
+    "    mni_ds[\"targets\"].extend(data_generator.train_ds['targets'][idx][:n_points])\n",
+    "\n",
+    "mni_ds = {k: np.array(v) for k, v in mni_ds.items()}\n",
+    "print(f\"Total number of data points to record the Mutual Information: {len(mni_ds['inputs'])}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The Mutual Information is a quantity that measures the amount of information that one distribution has about another. \n",
+    "In our case, we are interested in the amount of information that one class of data has about another.\n",
+    "\n",
+    "Since comparing every class to every other class is expensive, we will restrict the comparison to the classes [0, 1, 8]."
+   ]
+  },
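+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below is a small, self-contained sketch of the label combinations we expect to be evaluated for the classes [0, 1, 8]; this enumeration is only a plausible ordering, the authoritative list is exposed as `ntk_combination_computation.label_combinations` after instantiation (see below)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from itertools import chain, combinations\n",
+    "\n",
+    "# All non-empty combinations of the selected class labels; one NTK (and hence\n",
+    "# one entropy value) is expected per combination.\n",
+    "labels = [0, 1, 8]\n",
+    "combos = list(chain.from_iterable(combinations(labels, r) for r in range(1, len(labels) + 1)))\n",
+    "print(combos)  # e.g. [(0,), (1,), (8,), (0, 1), (0, 8), (1, 8), (0, 1, 8)]"
+   ]
+  },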
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ntk_combination_computation = nl.analysis.JAXNTKCombinations(\n",
+    "    apply_fn=model.ntk_apply_fn, \n",
+    "    class_labels=[0, 1, 8], # Selecting the classes to compute the Neural MI for\n",
+    "    batch_size=10,\n",
+    ")\n",
+    "mni_recorder = nl.training_recording.JaxRecorder(\n",
+    "    name=\"mni_recorder\",\n",
+    "    measurements=[\n",
+    "        NTKEntropy(name=\"ntk_entropy\", effective=False, normalize_eigenvalues=True),\n",
+    "    ],\n",
+    "    storage_path=\".\",\n",
+    "    update_rate=1,\n",
+    "    chunk_size=1e5\n",
+    ")\n",
+    "mni_recorder.instantiate_recorder(\n",
+    "    data_set=mni_ds, \n",
+    "    ntk_computation=ntk_combination_computation\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = nl.training_strategies.SimpleTraining(\n",
+    "    model=model, \n",
+    "    loss_fn=nl.loss_functions.CrossEntropyLoss(),\n",
+    "    recorders=[\n",
+    "        train_recorder, \n",
+    "        test_recorder, \n",
+    "        mni_recorder\n",
+    "    ],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batched_training_metrics = trainer.train_model(\n",
+    "    train_ds=data_generator.train_ds, \n",
+    "    test_ds=data_generator.test_ds,\n",
+    "    batch_size=10,\n",
+    "    epochs=100,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Checking the results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_report = train_recorder.gather()\n",
+    "test_report = test_recorder.gather()\n",
+    "mni_report = mni_recorder.gather()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, axs = plt.subplots(1, 2, figsize=(10, 4))\n",
+    "\n",
+    "axs[0].plot(train_report[\"loss\"], label=\"train\")\n",
+    "axs[0].plot(test_report[\"loss\"], label=\"test\")\n",
+    "axs[0].set_yscale(\"log\")\n",
+    "axs[0].set_xlabel(\"Epoch\")\n",
+    "axs[0].set_ylabel(\"Loss\")\n",
+    "\n",
+    "axs[1].plot(train_report[\"accuracy\"], label=\"train\")\n",
+    "axs[1].plot(test_report[\"accuracy\"], label=\"test\")\n",
+    "axs[1].set_xlabel(\"Epoch\")\n",
+    "axs[1].set_ylabel(\"Accuracy\")\n",
+    "\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Compute the Neural MI\n",
+    "\n",
+    "*To obtain results in the following part, the `mni_recorder` must be included in the `recorders` list when defining the `trainer` object above.*\n",
+    "\n",
+    "\n",
+    "Using the `JAXNTKCombinations` module, we obtain a set of entropy values for all the combinations of the classes. We can then compute the Neural MI using the entropy values.\n",
+    "\n",
+    "The mutual information of two correlated subsystems is obtained by:\n",
+    "\n",
+    "$$I(X;Y) = S(X) + S(Y) - S(X,Y)$$\n",
+    "\n",
+    "where $S(X)$ is the entropy of the first subsystem, $S(Y)$ is the entropy of the second subsystem, and $S(X,Y)$ is the joint entropy of the two subsystems.\n",
+    "Using this formula, we can compute the Mutual Information of the classes of the data.\n",
+    "The value of $I(X;Y)$ will, however, depend on the magnitude of the entropy values. For that reason we will normalize the Mutual Information by the sum of the entropies of the two classes:\n",
+    "\n",
+    "$$\\tilde{I}(X;Y) = \\frac{2 \\cdot I(X;Y)}{S(X) + S(Y)} \\in [0, 1]$$\n"
+   ]
+  },
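+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick numerical illustration of the two formulas above, here is a toy example with hypothetical entropy values (not taken from the recording):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical entropy values, purely to illustrate the formulas above.\n",
+    "S_x, S_y, S_xy = 1.0, 1.2, 1.6\n",
+    "\n",
+    "I_xy = S_x + S_y - S_xy          # mutual information: 0.6\n",
+    "I_norm = 2 * I_xy / (S_x + S_y)  # normalized to [0, 1]: ~0.545\n",
+    "print(I_xy, I_norm)"
+   ]
+  },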
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "entropies = np.array(mni_report['ntk_entropy'])\n",
+    "\n",
+    "print(f\"We obtain one entropy value for each label combination: {entropies.shape}\")\n",
+    "print(f\"The label combinations are: {ntk_combination_computation.label_combinations}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mni_norm = {\n",
+    "    \"I(0, 1)\": 2 * (entropies[:, 0] + entropies[:, 1] - entropies[:, 3]) / (entropies[:, 0] + entropies[:, 1]),\n",
+    "    \"I(0, 8)\": 2 * (entropies[:, 0] + entropies[:, 2] - entropies[:, 4]) / (entropies[:, 0] + entropies[:, 2]),\n",
+    "    \"I(1, 8)\": 2 * (entropies[:, 1] + entropies[:, 2] - entropies[:, 5]) / (entropies[:, 1] + entropies[:, 2]),\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, axs = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True)\n",
+    "\n",
+    "# Plot the Entropies \n",
+    "axs[0].plot(entropies[:, 0], label=\"S(0)\")\n",
+    "axs[0].plot(entropies[:, 1], label=\"S(1)\")\n",
+    "axs[0].plot(entropies[:, 2], label=\"S(8)\")\n",
+    "axs[0].set_xlabel(\"Epoch\")\n",
+    "axs[0].set_ylabel(\"Entropy\")\n",
+    "axs[0].legend()\n",
+    "\n",
+    "for key, value in mni_norm.items():\n",
+    "    axs[1].plot(value, label=key)\n",
+    "axs[1].set_xlabel(\"Epoch\")\n",
+    "axs[1].set_ylabel(\"Normalized MI\")\n",
+    "axs[1].legend()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "jax",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/ResNet-Example.ipynb b/examples/ResNet-Example.ipynb
index e60e606..d274030 100644
--- a/examples/ResNet-Example.ipynb
+++ b/examples/ResNet-Example.ipynb
@@ -23,6 +23,10 @@
     "\n",
     "import znnl as nl\n",
     "\n",
+    "from papyrus.measurements import (\n",
+    "    Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy\n",
+    ")\n",
+    "\n",
     "import numpy as np\n",
     "import optax\n",
     "\n",
@@ -147,8 +151,6 @@
     "model = HuggingFaceFlaxModel(\n",
     "    _model, \n",
     "    optax.adam(learning_rate=1e-3),\n",
-    "    store_on_device=False,\n",
-    "    batch_size=2,\n",
     ")"
    ]
   },
@@ -161,18 +163,25 @@
    "source": [
     "train_recorder = nl.training_recording.JaxRecorder(\n",
     "    name=\"train_recorder\",\n",
-    "    loss=True,\n",
-    "    accuracy=True,\n",
-    "    ntk=True,\n",
-    "    covariance_entropy=True,\n",
-    "    magnitude_variance=True, \n",
-    "    trace=True,\n",
-    "    loss_derivative=True,\n",
+    "    storage_path=\".\",\n",
+    "    measurements=[\n",
+    "        Loss(apply_fn=nl.loss_functions.CrossEntropyLoss()),\n",
+    "        Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy()),\n",
+    "        NTKTrace(name=\"ntk_trace\"),\n",
+    "        NTKEntropy(name=\"ntk_entropy\"),\n",
+    "        NTK(name=\"ntk\"),\n",
+    "        NTKSelfEntropy(name=\"ntk_self_entropy\"),\n",
+    "    ],\n",
     "    update_rate=1, \n",
-    "    chunk_size=1000,\n",
+    ")\n",
+    "ntk_computation = 
nl.analysis.JAXNTKComputation(\n", + " apply_fn=model.ntk_apply_fn, \n", + " batch_size=10, \n", ")\n", "train_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " ntk_computation=ntk_computation,\n", + " model=model,\n", ")\n", "\n", "trainer = nl.training_strategies.SimpleTraining(\n", @@ -221,8 +230,9 @@ "metadata": {}, "outputs": [], "source": [ - "train_report = train_recorder.gather_recording()\n", - "num_params = jax.flatten_util.ravel_pytree(model.model_state.params)[0].shape" + "train_report = train_recorder.gather()\n", + "num_params = jax.flatten_util.ravel_pytree(model.model_state.params)[0].shape\n", + "print(f\"Number of parameters: {num_params}\")" ] }, { @@ -233,8 +243,8 @@ "outputs": [], "source": [ "plt.plot(batch_wise_training_metrics['train_losses'], label='train loss')\n", - "plt.plot(train_report.covariance_entropy, label=\"covariance_entropy\")\n", - "plt.plot(train_report.trace/num_params, label=\"trace\")\n", + "plt.plot(train_report['ntk_entropy'], label=\"covariance_entropy\")\n", + "plt.plot(train_report['ntk_trace'], label=\"trace\")\n", "plt.yscale(\"log\")\n", "plt.legend()\n", "plt.show()" @@ -301,7 +311,6 @@ "resnet = HuggingFaceFlaxModel(\n", " _resnet,\n", " optax.sgd(learning_rate=1e-4),\n", - " batch_size=3,\n", ")\n", "\n", "\n", @@ -313,12 +322,14 @@ "\n", "train_recorder = JaxRecorder(\n", " name=\"train_recorder\",\n", - " loss=True,\n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=CrossEntropyLoss())],\n", " update_rate=1, \n", " chunk_size=1e5\n", ")\n", "train_recorder.instantiate_recorder(\n", - " data_set=train_ds\n", + " data_set=train_ds, \n", + " model=resnet,\n", ")\n", "\n", "\n", @@ -363,15 +374,23 @@ "source": [ "# import matplotlib.pyplot as plt\n", "\n", - "# train_report = train_recorder.gather_recording()\n", + "# train_report = train_recorder.gather()\n", "\n", - "# plt.plot(train_report.loss, label=\"loss using train=False\")\n", + "# plt.plot(train_report['loss'], label=\"loss using train=False\")\n", "# plt.yscale(\"log\")\n", "# plt.plot(batched_loss['train_losses'], label=\"loss using train=True\")\n", "# plt.yscale(\"log\")\n", "# plt.title(\"Train Losses\")\n", "# plt.legend()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9003dbb", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/Using-Training-Strategies.ipynb b/examples/Using-Training-Strategies.ipynb index 4f559b6..25925b4 100644 --- a/examples/Using-Training-Strategies.ipynb +++ b/examples/Using-Training-Strategies.ipynb @@ -12,6 +12,9 @@ "\n", "import copy\n", "import znnl as nl\n", + "\n", + "from papyrus.measurements import Loss\n", + "\n", "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", @@ -79,14 +82,12 @@ " nt_module=architecture,\n", " optimizer=optax.adam(learning_rate=0.02),\n", " input_shape=input_shape,\n", - " batch_size=10,\n", ")\n", "\n", "predictor_model = nl.models.NTModel(\n", " nt_module=architecture,\n", " optimizer=optax.adam(learning_rate=0.02),\n", " input_shape=input_shape,\n", - " batch_size=10,\n", ")" ] }, @@ -136,12 +137,14 @@ "source": [ "simple_recorder = nl.training_recording.JaxRecorder(\n", " name=\"simple_recorder\",\n", - " loss=True, \n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", 
"simple_recorder.instantiate_recorder(\n", - " data_set=dataset\n", + " data_set=dataset, \n", + " model=predictor_model,\n", ")\n", "simple_trainer = nl.training_strategies.SimpleTraining(\n", " model=None,\n", @@ -170,13 +173,15 @@ "outputs": [], "source": [ "partitioned_recorder = nl.training_recording.JaxRecorder(\n", - " name=\"simple_recorder\",\n", - " loss=True, \n", + " name=\"part_recorder\",\n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "partitioned_recorder.instantiate_recorder(\n", - " data_set=dataset\n", + " data_set=dataset, \n", + " model=predictor_model,\n", ")\n", "partitioned_trainer = nl.training_strategies.PartitionedTraining(\n", " model=None,\n", @@ -203,13 +208,15 @@ "outputs": [], "source": [ "LaR_recorder = nl.training_recording.JaxRecorder(\n", - " name=\"simple_recorder\",\n", - " loss=True, \n", + " name=\"lar_recorder\",\n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "LaR_recorder.instantiate_recorder(\n", - " data_set=dataset\n", + " data_set=dataset, \n", + " model=predictor_model,\n", ")\n", "LaR_trainer = nl.training_strategies.LossAwareReservoir(\n", " model=None,\n", @@ -310,9 +317,9 @@ "metadata": {}, "outputs": [], "source": [ - "simple_report = simple_recorder.gather_recording()\n", - "pertitioned_report = partitioned_recorder.gather_recording()\n", - "LaR_report = LaR_recorder.gather_recording()" + "simple_report = simple_recorder.gather()\n", + "pertitioned_report = partitioned_recorder.gather()\n", + "LaR_report = LaR_recorder.gather()" ] }, { @@ -322,9 +329,9 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(simple_report.loss, '-', mfc='None', label=\"SimpleTraining\")\n", - "plt.plot(pertitioned_report.loss, '-', mfc='None', label=\"PartitionedTraining\")\n", - "plt.plot(LaR_report.loss, '-', mfc='None', label=\"LossAwareReservoir\")\n", + "plt.plot(simple_report['loss'], '-', mfc='None', label=\"SimpleTraining\")\n", + "plt.plot(pertitioned_report['loss'], '-', mfc='None', label=\"PartitionedTraining\")\n", + "plt.plot(LaR_report['loss'], '-', mfc='None', label=\"LossAwareReservoir\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", "plt.legend()\n", @@ -389,7 +396,6 @@ " nt_module=architecture,\n", " optimizer=optax.adam(learning_rate=0.02),\n", " input_shape=input_shape,\n", - " batch_size=10,\n", ")" ] }, @@ -410,12 +416,14 @@ "source": [ "pre_train_recorder = nl.training_recording.JaxRecorder(\n", " name=\"simple_recorder\",\n", - " loss=True, \n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "pre_train_recorder.instantiate_recorder(\n", - " data_set=pre_train_ds\n", + " data_set=pre_train_ds, \n", + " model=model,\n", ")\n", "pre_trainer = nl.training_strategies.SimpleTraining(\n", " model=model,\n", @@ -449,7 +457,7 @@ "metadata": {}, "outputs": [], "source": [ - "pre_train_report = pre_train_recorder.gather_recording()" + "pre_train_report = pre_train_recorder.gather()" ] }, { @@ -459,7 +467,7 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(pre_train_report.loss, '-', mfc='None', label=\"Pre-training\")\n", + 
"plt.plot(pre_train_report['loss'], '-', mfc='None', label=\"Pre-training\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", "plt.yscale(\"log\")\n", @@ -528,12 +536,14 @@ "source": [ "simple_recorder = nl.training_recording.JaxRecorder(\n", " name=\"simple_recorder\",\n", - " loss=True, \n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "simple_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model,\n", ")\n", "simple_trainer = nl.training_strategies.SimpleTraining(\n", " model=copy.deepcopy(model),\n", @@ -561,12 +571,14 @@ "source": [ "partitioned_recorder = nl.training_recording.JaxRecorder(\n", " name=\"simple_recorder\",\n", - " loss=True, \n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "partitioned_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model,\n", ")\n", "partitioned_trainer = nl.training_strategies.PartitionedTraining(\n", " model=copy.deepcopy(model),\n", @@ -592,12 +604,14 @@ "source": [ "LaR_recorder = nl.training_recording.JaxRecorder(\n", " name=\"simple_recorder\",\n", - " loss=True, \n", + " storage_path=\".\",\n", + " measurements=[Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2))],\n", " update_rate=1, \n", " chunk_size=1e10 # Big Chunk-size to prevent saving the recordings.\n", ")\n", "LaR_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=model,\n", ")\n", "LaR_trainer = nl.training_strategies.LossAwareReservoir(\n", " model=copy.deepcopy(model),\n", @@ -661,9 +675,9 @@ "metadata": {}, "outputs": [], "source": [ - "simple_report = simple_recorder.gather_recording()\n", - "pertitioned_report = partitioned_recorder.gather_recording()\n", - "LaR_report = LaR_recorder.gather_recording()" + "simple_report = simple_recorder.gather()\n", + "pertitioned_report = partitioned_recorder.gather()\n", + "LaR_report = LaR_recorder.gather()" ] }, { @@ -673,15 +687,23 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(simple_report.loss, '-', mfc='None', label=\"SimpleTraining\")\n", - "plt.plot(pertitioned_report.loss, '-', mfc='None', label=\"PartitionedTraining\")\n", - "plt.plot(LaR_report.loss, '-', mfc='None', label=\"LossAwareReservoir\")\n", + "plt.plot(simple_report['loss'], '-', mfc='None', label=\"SimpleTraining\")\n", + "plt.plot(pertitioned_report['loss'], '-', mfc='None', label=\"PartitionedTraining\")\n", + "plt.plot(LaR_report['loss'], '-', mfc='None', label=\"LossAwareReservoir\")\n", "plt.yscale(\"log\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", "plt.legend()\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c500d5b8", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -700,7 +722,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/Using-the-Data-Recorders.ipynb b/examples/Using-the-Data-Recorders.ipynb index 282ac97..0283a6e 100644 --- a/examples/Using-the-Data-Recorders.ipynb +++ 
b/examples/Using-the-Data-Recorders.ipynb @@ -27,7 +27,9 @@ "import matplotlib.pyplot as plt\n", "\n", "from neural_tangents import stax\n", - "import optax" + "import optax\n", + "\n", + "from papyrus.measurements import Loss, Accuracy" ] }, { @@ -91,15 +93,23 @@ "outputs": [], "source": [ "train_recorder = nl.training_recording.JaxRecorder(\n", - " name=\"train_recorder\",\n", - " loss=True,\n", - " accuracy=True,\n", - " update_rate=1\n", + " name=\"train_recorder\", # name of the recorder\n", + " storage_path=\".\", # where to save the data\n", + " measurements=[\n", + " Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), \n", + " Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy())\n", + " ], # list of measurements to record\n", + " chunk_size=1e5, # number of samples to keep in memory before writing to disk\n", + " update_rate=1 # number of epochs between recording\n", ")\n", "test_recorder = nl.training_recording.JaxRecorder(\n", " name=\"test_recorder\",\n", - " loss=True,\n", - " accuracy=True,\n", + " storage_path=\".\",\n", + " measurements=[\n", + " Loss(apply_fn=nl.loss_functions.MeanPowerLoss(order=2)), \n", + " Accuracy(apply_fn=nl.accuracy_functions.LabelAccuracy())\n", + " ],\n", + " chunk_size=1e5,\n", " update_rate=10\n", ")" ] @@ -113,7 +123,13 @@ "source": [ "## Create Model and Prepare Recording\n", "\n", - "Before we train, we need to create a model and prepare the recorders. In this time, we add data and a model to the recorders. Note, any data can be added to the recorders here, even validation data." + "Before we train, we need to create a model and prepare the recorders. \n", + "\n", + "When creating the recorders, we defined the intervals and properties we want to record.\n", + "Here, in the `instantiation` method, we give the recorders data and the functions it needs for recording.\n", + "\n", + "The example below shows that the train recorder gets the train data and the test recorder gets the test data, however, both recorders get the same model, as they operate on the same model.\n", + "Note, any data can be added to the recorders here, even validation data.\n" ] }, { @@ -138,10 +154,12 @@ "outputs": [], "source": [ "train_recorder.instantiate_recorder(\n", - " data_set=data_generator.train_ds\n", + " data_set=data_generator.train_ds, \n", + " model=production_model\n", ")\n", "test_recorder.instantiate_recorder(\n", - " data_set=data_generator.test_ds\n", + " data_set=data_generator.test_ds, \n", + " model=production_model\n", ")" ] }, @@ -214,8 +232,8 @@ "metadata": {}, "outputs": [], "source": [ - "train_report = train_recorder.gather_recording()\n", - "test_report = test_recorder.gather_recording()" + "train_report = train_recorder.gather()\n", + "test_report = test_recorder.gather()" ] }, { @@ -225,8 +243,8 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(train_report.loss, 'o', mfc='None', label=\"Train\")\n", - "plt.plot(np.linspace(0, 100, 10), test_report.loss, '.', mfc=\"None\", label=\"Test\")\n", + "plt.plot(train_report[\"loss\"], 'o', mfc='None', label=\"Train\")\n", + "plt.plot(np.linspace(0, 100, 10), test_report[\"loss\"], '.-', mfc=\"None\", label=\"Test\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Loss\")\n", "plt.legend()\n", @@ -240,8 +258,8 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(train_report.accuracy, 'o', mfc='None', label=\"Train\")\n", - "plt.plot(np.linspace(0, 100, 10), test_report.accuracy, '.', mfc=\"None\", label=\"Test\")\n", + "plt.plot(train_report[\"accuracy\"], 'o', mfc='None', 
label=\"Train\")\n", + "plt.plot(np.linspace(0, 100, 10), test_report[\"accuracy\"], '.-', mfc=\"None\", label=\"Test\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Accuracy\")\n", "plt.legend()\n", @@ -265,8 +283,18 @@ "metadata": {}, "outputs": [], "source": [ - "train_recorder.dump_records()\n", - "test_recorder.dump_records()" + "train_recorder.store()\n", + "test_recorder.store()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e1c46e2", + "metadata": {}, + "outputs": [], + "source": [ + "train_recorder.load()['loss'].shape" ] } ], @@ -286,7 +314,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/requirements.txt b/requirements.txt index 71bb1d3..6ebee20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ tensorflow>=2.16.1 jupyter>=1.0.0 transformers>=4.40.0 jax>=0.4.26 -jaxlib>=0.4.26 \ No newline at end of file +jaxlib>=0.4.26 +papyrus @ git+https://github.com/zincware/papyrus \ No newline at end of file diff --git a/znnl/agents/approximate_maximum_entropy.py b/znnl/agents/approximate_maximum_entropy.py index ecbf2ac..c44c25b 100644 --- a/znnl/agents/approximate_maximum_entropy.py +++ b/znnl/agents/approximate_maximum_entropy.py @@ -30,6 +30,7 @@ from znnl.agents.agent import Agent from znnl.analysis.entropy import EntropyAnalysis +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.data import DataGenerator from znnl.models import JaxModel from znnl.utils.prng import PRNGKey @@ -66,6 +67,9 @@ def __init__( self.target_set: np.ndarray self.target_indices: list + self.ntk_computation = JAXNTKComputation( + target_network.ntk_apply_fn, trace_axes=(-1,) + ) def _compute_entropy(self, dataset: np.ndarray): """ @@ -81,7 +85,7 @@ def _compute_entropy(self, dataset: np.ndarray): entropy : float Entropy pf the dataset. """ - ntk = self._compute_ntk(dataset) + ntk = self._compute_ntk({"inputs": dataset, "targets": None}) entropy_calculator = EntropyAnalysis(matrix=ntk) @@ -96,7 +100,9 @@ def _compute_ntk(self, dataset: np.ndarray): empirical_ntk : np.ndarray The empirical NTK matrix of the target network. """ - return self.target_network.compute_ntk(dataset)["empirical"] + return self._compute_ntk.compute_ntk( + {"params": self.target_network.model_state.params}, dataset + ) def build_dataset( self, target_size: int = None, visualize: bool = False, report: bool = True diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py index 94701c5..d8b6c50 100644 --- a/znnl/analysis/__init__.py +++ b/znnl/analysis/__init__.py @@ -27,10 +27,18 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis +from znnl.analysis.jax_ntk import JAXNTKComputation +from znnl.analysis.jax_ntk_classwise import JAXNTKClassWise +from znnl.analysis.jax_ntk_combinations import JAXNTKCombinations +from znnl.analysis.jax_ntk_subsampling import JAXNTKSubsampling from znnl.analysis.loss_fn_derivative import LossDerivative __all__ = [ EntropyAnalysis.__name__, EigenSpaceAnalysis.__name__, LossDerivative.__name__, + JAXNTKComputation.__name__, + JAXNTKClassWise.__name__, + JAXNTKSubsampling.__name__, + JAXNTKCombinations.__name__, ] diff --git a/znnl/analysis/jax_ntk.py b/znnl/analysis/jax_ntk.py new file mode 100644 index 0000000..5109bd7 --- /dev/null +++ b/znnl/analysis/jax_ntk.py @@ -0,0 +1,172 @@ +""" +ZnNL: A Zincwarecode package. 
diff --git a/znnl/analysis/jax_ntk.py b/znnl/analysis/jax_ntk.py
new file mode 100644
index 0000000..5109bd7
--- /dev/null
+++ b/znnl/analysis/jax_ntk.py
@@ -0,0 +1,172 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+from abc import ABC
+from typing import Callable, List, Optional
+
+import jax.numpy as np
+import neural_tangents as nt
+from papyrus.utils.matrix_utils import flatten_rank_4_tensor
+
+
+class JAXNTKComputation(ABC):
+    """
+    Class for computing the empirical Neural Tangent Kernel (NTK) using the
+    neural-tangents library (implemented in JAX).
+    """
+
+    def __init__(
+        self,
+        apply_fn: Callable,
+        batch_size: int = 10,
+        ntk_implementation: nt.NtkImplementation = None,
+        trace_axes: tuple = (),
+        store_on_device: bool = False,
+        flatten: bool = True,
+        data_keys: Optional[List[str]] = None,
+    ):
+        """
+        Constructor of the JAX NTK computation class.
+
+        Parameters
+        ----------
+        apply_fn : Callable
+            The function that applies the neural network to an input.
+            This function should be implemented using JAX. It should take in a
+            dictionary of parameters (and possibly other arguments) and return the
+            output of the neural network.
+            For models taking in `batch_stats` the apply function should look like::
+                def apply_fn(params, x):
+                    return model.apply(
+                        params, x, train=False, mutable=['batch_stats']
+                    )[0]
+        batch_size : int
+            Size of the batch to use in the NTK calculation.
+        ntk_implementation : Union[None, NtkImplementation] (default = None)
+            Implementation of the NTK computation.
+            The implementation depends on the trace_axes and the model
+            architecture. The default automatically takes the trace_axes into
+            account: for trace_axes=() the default is NTK_VECTOR_PRODUCTS,
+            for all other cases including trace_axes=(-1,) the default is
+            JACOBIAN_CONTRACTION. For more specific use cases, the user can
+            set the implementation manually.
+            Information about the implementation and specific requirements can be
+            found in the neural_tangents documentation.
+        trace_axes : Union[int, Sequence[int]]
+            Tracing over axes of the NTK.
+            The default value trace_axes=() computes the full NTK.
+            Setting trace_axes=(-1,) traces over the output dimension and
+            reduces the NTK to a tensor of rank 2.
+        store_on_device : bool, default False
+            Whether to store the NTK on the device or not.
+            This should be set False for large NTKs that do not fit in GPU memory.
+        flatten : bool, default True
+            If True, the NTK shape is checked and flattened into a 2D matrix, if
+            required.
+        data_keys : List[str], default ["inputs", "targets"]
+            The keys used to define inputs and targets in the dataset.
+            These keys are used to extract values from the dataset dictionary in
+            the `compute_ntk` method.
+            Note that the first key has to refer to the input data and the second
+            key to the targets / labels of the dataset.
+        """
+        self.apply_fn = apply_fn
+        self.batch_size = batch_size
+        self.ntk_implementation = ntk_implementation
+        self.trace_axes = trace_axes
+        self.store_on_device = store_on_device
+        self.flatten = flatten
+        self.data_keys = data_keys
+
+        self._ntk_shape = None
+        self._is_flattened = False
+
+        # Set default data keys
+        if self.data_keys is None:
+            self.data_keys = ["inputs", "targets"]
+
+        # Prepare NTK calculation
+        if self.ntk_implementation is None:
+            if trace_axes == ():
+                self.ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS
+            else:
+                self.ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION
+        self.empirical_ntk = nt.batch(
+            nt.empirical_ntk_fn(
+                f=apply_fn,
+                trace_axes=trace_axes,
+                implementation=self.ntk_implementation,
+            ),
+            batch_size=batch_size,
+            store_on_device=store_on_device,
+        )
+
+    def _check_shape(self, ntk: np.ndarray) -> np.ndarray:
+        """
+        Check the shape of the NTK matrix and flatten it if required.
+
+        Parameters
+        ----------
+        ntk : np.ndarray
+            The NTK matrix.
+
+        Returns
+        -------
+        np.ndarray
+            The (flattened, if required) NTK matrix.
+        """
+        self._ntk_shape = ntk.shape
+        if self.flatten and len(self._ntk_shape) > 2:
+            ntk, _ = flatten_rank_4_tensor(ntk)
+            self._is_flattened = True
+        return ntk
+
+    def compute_ntk(
+        self, params: dict, dataset_i: dict, dataset_j: Optional[dict] = None
+    ) -> List[np.ndarray]:
+        """
+        Compute the Neural Tangent Kernel (NTK) for the neural network.
+
+        Parameters
+        ----------
+        params : dict
+            The parameters of the neural network.
+        dataset_i : dict
+            The input dataset for the NTK computation.
+        dataset_j : Optional[dict]
+            Optional input dataset for the NTK computation.
+
+        Returns
+        -------
+        List[np.ndarray]
+            A list containing the NTK matrix.
+        """
+        x_i = dataset_i[self.data_keys[0]]
+        x_j = dataset_j[self.data_keys[0]] if dataset_j is not None else None
+        ntk = self.empirical_ntk(x_i, x_j, params)
+        ntk = self._check_shape(ntk)
+        return [ntk]
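A hedged sketch of the optional `dataset_j` argument, which computes a
rectangular NTK between two datasets (dataset names hypothetical, reusing the
`ntk_computation` object from the sketch above)::

    # NTK between train and test inputs instead of a square train-train NTK.
    cross_ntk = ntk_computation.compute_ntk(
        {"params": model.model_state.params},
        dataset_i={"inputs": x_train, "targets": y_train},
        dataset_j={"inputs": x_test, "targets": y_test},
    )[0]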
+    This means that axes 0 and 1 of the NTK matrix correspond to the same dataset.
+    More information can be found in the `compute_ntk` method.
+    """
+
+    def __init__(
+        self,
+        apply_fn: Callable,
+        batch_size: int = 10,
+        ntk_implementation: nt.NtkImplementation = None,
+        trace_axes: tuple = (),
+        store_on_device: bool = False,
+        flatten: bool = True,
+        data_keys: Optional[List[str]] = None,
+        ntk_size: int = None,
+    ):
+        """
+        Constructor of the JAX NTK computation class.
+
+        Parameters
+        ----------
+        apply_fn : Callable
+            The function that applies the neural network to an input.
+            This function should be implemented using JAX. It should take in a
+            dictionary of parameters (and possibly other arguments) and return the
+            output of the neural network.
+            For models taking in `batch_stats` the apply function should look like::
+                def apply_fn(params, x):
+                    return model.apply(
+                        params, x, train=False, mutable=['batch_stats']
+                    )[0]
+        batch_size : int
+            Size of the batch to use in the NTK calculation.
+        ntk_implementation : Union[None, NtkImplementation] (default = None)
+            Implementation of the NTK computation.
+            The implementation depends on the trace_axes and the model
+            architecture. The default automatically takes the trace_axes into
+            account: for trace_axes=() the default is NTK_VECTOR_PRODUCTS,
+            for all other cases, including trace_axes=(-1,), the default is
+            JACOBIAN_CONTRACTION. For more specific use cases, the user can
+            set the implementation manually.
+            Information about the implementations and their specific requirements
+            can be found in the neural_tangents documentation.
+        trace_axes : Union[int, Sequence[int]]
+            Tracing over axes of the NTK.
+            The default value is trace_axes=(), which computes the full NTK of
+            rank 4.
+            To reduce the NTK to a tensor of rank 2, set trace_axes=(-1,).
+        store_on_device : bool, default False
+            Whether to store the NTK on the device or not.
+            This should be set False for large NTKs that do not fit in GPU memory.
+        flatten : bool, default True
+            If True, the NTK shape is checked and flattened into a 2D matrix, if
+            required.
+        data_keys : List[str], default ["inputs", "targets"]
+            The keys used to define inputs and targets in the dataset.
+            These keys are used to extract values from the dataset dictionary in
+            the `compute_ntk` method.
+            Note that the first key has to refer to the input data and the second
+            key to the targets / labels of the dataset.
+        ntk_size : int (default = None)
+            Upper limit for the number of samples used for the NTK computation.
+        """
+        super().__init__(
+            apply_fn=apply_fn,
+            batch_size=batch_size,
+            ntk_implementation=ntk_implementation,
+            trace_axes=trace_axes,
+            store_on_device=store_on_device,
+            flatten=flatten,
+            data_keys=data_keys,
+        )
+
+        self._sample_indices = None
+        self.ntk_size = ntk_size
+
+    def _get_label_indices(self, dataset: dict) -> dict:
+        """
+        Group the data by label and return the indices of the samples to use for the
+        NTK computation.
+
+        Parameters
+        ----------
+        dataset : dict
+            The dataset containing the inputs and targets.
+
+        Returns
+        -------
+        sample_indices : dict
+            A dictionary containing the indices of the samples for each class, with
+            the class label as the key.
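+
+        Examples
+        --------
+        A sketch of the returned structure: for one-hot targets encoding the
+        label sequence [0, 0, 1, 0] (and `ntk_size=None`), the result is::
+
+            {0: np.array([0, 1, 3]), 1: np.array([2])}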
+ """ + targets = dataset[self.data_keys[1]] + + if len(targets.shape) > 1: + # If one-hot encoding is used, convert it to class labels + if targets.shape[1] > 1: + targets = np.argmax(targets, axis=1) + # If the targets are already class labels, squeeze the array + elif targets.shape[1] == 1: + targets = np.squeeze(targets, axis=1) + + unique_classes = np.unique(targets) + _indices = np.arange(targets.shape[0]) + sample_indices = {} + + for class_label in unique_classes: + # Create mask for samples of the current class + mask = targets == class_label + indices = np.compress(mask, _indices, axis=0) + if self.ntk_size is not None: + indices = indices[: self.ntk_size] + sample_indices[int(class_label)] = indices + + return sample_indices + + def _subsample_data(self, x: np.ndarray, sample_indices: dict) -> np.ndarray: + """ + Subsample the data based on indices. + + Parameters + ---------- + x : np.ndarray + The input data. + sample_indices : dict + The indices of the samples to use for the NTK computation. + + Returns + ------- + np.ndarray + The subsampled data. + """ + return jmap(lambda indices: np.take(x, indices, axis=0), sample_indices) + + def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray: + """ + Compute the NTK for the neural network. + + Parameters + ---------- + params : dict + The parameters of the neural network. + x_i : np.ndarray + The input to the neural network. + + Returns + ------- + np.ndarray + The NTK matrix. + """ + ntk = self.empirical_ntk(x_i, None, params) + ntk = self._check_shape(ntk) + return ntk + + def compute_ntk(self, params: dict, dataset_i: dict) -> List[np.ndarray]: + """ + Compute the Neural Tangent Kernel (NTK) for the neural network. + + Note + ---- + This method only accepts a single dataset for the NTK computation. This means + both axes of the NTK matrix correspond to the same dataset. + For that reason, this method only takes a single dataset as input. + + Parameters + ---------- + params : dict + The parameters of the neural network. + dataset_i : dict + The input dataset for the NTK computation. + + Returns + ------- + List[np.ndarray] + The NTK matrix. + """ + + self._sample_indices = self._get_label_indices(dataset_i) + + x_i = self._subsample_data(dataset_i[self.data_keys[0]], self._sample_indices) + + ntks = jmap(lambda x_i: self._compute_ntk(params, x_i), x_i) + + ntks = list(ntks.values()) + + # Get the maximum key in the sample indices i + max_key = max(self._sample_indices.keys()) + + # Fill in the missing classes with empty NTKs + for i in range(max_key): + if i not in self._sample_indices.keys(): + ntks.insert(i, np.zeros((0, 0))) + + return ntks diff --git a/znnl/analysis/jax_ntk_combinations.py b/znnl/analysis/jax_ntk_combinations.py new file mode 100644 index 0000000..510fd11 --- /dev/null +++ b/znnl/analysis/jax_ntk_combinations.py @@ -0,0 +1,315 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. 
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+from itertools import combinations
+from typing import Callable, List, Optional
+
+import jax.numpy as np
+import neural_tangents as nt
+from papyrus.utils.matrix_utils import flatten_rank_4_tensor, unflatten_rank_4_tensor
+
+from znnl.analysis.jax_ntk import JAXNTKComputation
+
+
+class JAXNTKCombinations(JAXNTKComputation):
+    """
+    Class for computing the empirical Neural Tangent Kernel (NTK) using the
+    neural-tangents library (implemented in JAX) of all possible class combinations.
+
+    This can be understood in the following way:
+    For a dataset of n labels, and a given selection of labels (e.g. 0, 2), the NTK
+    will be computed for all possible combinations made from the selected labels.
+    This means that the NTKs for the combinations
+    (0, 0), (0, 2), (2, 0), (2, 2) and (0+2, 0+2) will be computed.
+
+    Note
+    ----
+    This class is only implemented for computing the NTK of a single dataset.
+    This means that axes 0 and 1 of the NTK matrix correspond to the same dataset.
+    More information can be found in the `compute_ntk` method.
+    """
+
+    def __init__(
+        self,
+        apply_fn: Callable,
+        class_labels: List[int],
+        batch_size: int = 10,
+        ntk_implementation: nt.NtkImplementation = None,
+        trace_axes: tuple = (),
+        store_on_device: bool = False,
+        flatten: bool = True,
+        data_keys: Optional[List[str]] = None,
+    ):
+        """
+        Constructor of the JAX NTK computation class for class-label combinations.
+
+        Parameters
+        ----------
+        apply_fn : Callable
+            The function that applies the neural network to an input.
+            This function should be implemented using JAX. It should take in a
+            dictionary of parameters (and possibly other arguments) and return the
+            output of the neural network.
+            For models taking in `batch_stats` the apply function should look like::
+                def apply_fn(params, x):
+                    return model.apply(
+                        params, x, train=False, mutable=['batch_stats']
+                    )[0]
+        class_labels : List[int]
+            List of class labels for which to compute the NTK.
+        batch_size : int
+            Size of the batch to use in the NTK calculation.
+        ntk_implementation : Union[None, NtkImplementation] (default = None)
+            Implementation of the NTK computation.
+            The implementation depends on the trace_axes and the model
+            architecture. The default automatically takes the trace_axes into
+            account: for trace_axes=() the default is NTK_VECTOR_PRODUCTS,
+            for all other cases, including trace_axes=(-1,), the default is
+            JACOBIAN_CONTRACTION. For more specific use cases, the user can
+            set the implementation manually.
+            Information about the implementations and their specific requirements
+            can be found in the neural_tangents documentation.
+        trace_axes : Union[int, Sequence[int]]
+            Tracing over axes of the NTK.
+            The default value is trace_axes=(), which computes the full NTK of
+            rank 4.
+            To reduce the NTK to a tensor of rank 2, set trace_axes=(-1,).
+        store_on_device : bool, default False
+            Whether to store the NTK on the device or not.
+            This should be set False for large NTKs that do not fit in GPU memory.
+        flatten : bool, default True
+            If True, the NTK shape is checked and flattened into a 2D matrix, if
+            required.
+        data_keys : List[str], default ["inputs", "targets"]
+            The keys used to define inputs and targets in the dataset.
+            These keys are used to extract values from the dataset dictionary in
+            the `compute_ntk` method.
+            Note that the first key has to refer to the input data and the second
+            key to the targets / labels of the dataset.
+        """
+        super().__init__(
+            apply_fn=apply_fn,
+            batch_size=batch_size,
+            ntk_implementation=ntk_implementation,
+            trace_axes=trace_axes,
+            store_on_device=store_on_device,
+            flatten=flatten,
+            data_keys=data_keys,
+        )
+
+        self.class_labels = class_labels
+
+        # Compute all possible combinations of the class labels
+        self.label_combinations = self._compute_combinations()
+
+    def _reduce_data_to_labels(self, dataset: dict) -> dict:
+        """
+        Reduce the dataset to only contain the selected class labels.
+
+        Parameters
+        ----------
+        dataset : dict
+            The dataset containing the inputs and targets.
+
+        Returns
+        -------
+        dict
+            The dataset containing only the selected class labels.
+        """
+        targets = dataset[self.data_keys[1]]
+
+        if len(targets.shape) > 1:
+            # If one-hot encoding is used, convert it to class labels
+            if targets.shape[1] > 1:
+                targets = np.argmax(targets, axis=1)
+            # If the targets are already class labels, squeeze the array
+            elif targets.shape[1] == 1:
+                targets = np.squeeze(targets, axis=1)
+
+        mask = np.isin(targets, np.array(self.class_labels))
+        dataset_reduced = {}
+        for key, value in dataset.items():
+            dataset_reduced[key] = np.compress(mask, value, axis=0)
+
+        return dataset_reduced
+
+    def _get_label_indices(self, dataset: dict) -> dict:
+        """
+        Group the data by label and return the indices of the samples to use for the
+        NTK computation.
+
+        Parameters
+        ----------
+        dataset : dict
+            The dataset containing the inputs and targets.
+
+        Returns
+        -------
+        sample_indices : dict
+            A dictionary containing the indices of the samples for each class, with
+            the class label as the key.
+        """
+        targets = dataset[self.data_keys[1]]
+
+        if len(targets.shape) > 1:
+            # If one-hot encoding is used, convert it to class labels
+            if targets.shape[1] > 1:
+                targets = np.argmax(targets, axis=1)
+            # If the targets are already class labels, squeeze the array
+            elif targets.shape[1] == 1:
+                targets = np.squeeze(targets, axis=1)
+
+        _indices = np.arange(targets.shape[0])
+        sample_indices = {}
+
+        for class_label in self.class_labels:
+            # Create mask for samples of the current class
+            mask = targets == class_label
+            indices = np.compress(mask, _indices, axis=0)
+            sample_indices[int(class_label)] = indices
+
+        return sample_indices
+
+    def _compute_combinations(self) -> List[tuple]:
+        """
+        Compute all possible combinations of the class labels.
+
+        The combinations of all lengths (from single labels up to all selected
+        labels) are computed for the class labels contained in the
+        `class_labels` attribute.
+
+        Returns
+        -------
+        label_combinations : List[tuple]
+            All combinations of the selected class labels.
+        """
+        label_combinations = []
+        # Compute all possible combinations of the class labels
+        for i in range(1, len(self.class_labels) + 1):
+            label_combinations.extend(combinations(self.class_labels, i))
+
+        return label_combinations
+
+    def _take_sub_ntk(
+        self, ntk: np.ndarray, label_indices: dict, combination: tuple
+    ) -> np.ndarray:
+        """
+        Take a submatrix of the NTK matrix using np.ix_.
+
+        Parameters
+        ----------
+        ntk : np.ndarray
+            The NTK matrix.
+        label_indices : dict
+            A dictionary containing the indices of the samples for each class, with
+            the class label as the key.
+        combination : tuple
+            The combination of class labels to use for the submatrix.
+
+        Returns
+        -------
+        np.ndarray
+            The submatrix of the NTK matrix.
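+
+        Examples
+        --------
+        A sketch of the indexing step for a rank-2 NTK: with
+        `label_indices = {0: [0, 1], 2: [4]}` and `combination = (0, 2)`, the
+        concatenated indices are [0, 1, 4] and the extracted block is::
+
+            ntk[np.ix_([0, 1, 4], [0, 1, 4])]  # shape (3, 3)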
+ """ + indices = [label_indices[label] for label in combination] + indices = np.concatenate(indices) + + # Check if flattening was performed + if self._is_flattened: + ntk = unflatten_rank_4_tensor(ntk, self._ntk_shape) + + ntk_sub = ntk[np.ix_(indices, indices)] + + if self.flatten: + ntk_sub, _ = flatten_rank_4_tensor(ntk_sub) + + return ntk_sub + + def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray: + """ + Compute the NTK for the neural network. + + Parameters + ---------- + params : dict + The parameters of the neural network. + x_i : np.ndarray + The input to the neural network. + + Returns + ------- + np.ndarray + The NTK matrix. + """ + ntk = self.empirical_ntk(x_i, None, params) + ntk = self._check_shape(ntk) + return ntk + + def compute_ntk(self, params: dict, dataset_i: dict) -> List[np.ndarray]: + """ + Compute the Neural Tangent Kernel (NTK) for the neural network. + + Note + ---- + This method only accepts a single dataset for the NTK computation. This means + both axes of the NTK matrix correspond to the same dataset. + + Parameters + ---------- + params : dict + The parameters of the neural network. + dataset_i : dict + The input dataset for the NTK computation. + + Returns + ------- + List[np.ndarray] + List of NTK matrices for all possible class combinations. + What class combinations each NTK corresponds to can be found in the + `label_combinations` attribute. + """ + + # Reduce the dataset to the selected class labels + dataset_reduced = self._reduce_data_to_labels(dataset_i) + + # Compute the NTK for the reduced dataset + ntk = self._compute_ntk(params, dataset_reduced[self.data_keys[0]]) + + # Get the label indices referencing to the reduced dataset + label_indices = self._get_label_indices(dataset_reduced) + + # Create copies of the NTK for all possible class combinations + ntks = [] + for combination in self.label_combinations: + sub_ntk = self._take_sub_ntk(ntk, label_indices, combination) + ntks.append(sub_ntk) + + return ntks diff --git a/znnl/analysis/jax_ntk_subsampling.py b/znnl/analysis/jax_ntk_subsampling.py new file mode 100644 index 0000000..836ef9c --- /dev/null +++ b/znnl/analysis/jax_ntk_subsampling.py @@ -0,0 +1,229 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +from typing import Callable, List, Optional + +import jax.numpy as np +import neural_tangents as nt +from jax import random +from jax.tree_util import tree_map as jmap + +from znnl.analysis.jax_ntk import JAXNTKComputation + + +class JAXNTKSubsampling(JAXNTKComputation): + """ + Class for computing the empirical Neural Tangent Kernel (NTK) using the + neural-tangents library (implemented in JAX) with subsampling. + + This class is a subclass of JAXNTKComputation and adds the functionality of + subsampling the data before computing the NTK. + Subsampling is useful when the data is too large to compute the NTK on the + entire dataset. 
+    Subsampling is done by splitting the data randomly into batches of size
+    `ntk_size` and computing the NTK on each part separately.
+    This is equivalent to computing block-diagonal elements of the NTK of size
+    `ntk_size`.
+    The `compute_ntk` method of this class will return a list of
+    `len(data) // ntk_size` NTK matrices.
+    """
+
+    def __init__(
+        self,
+        apply_fn: Callable,
+        ntk_size: int,
+        seed: int = 0,
+        batch_size: int = 10,
+        ntk_implementation: nt.NtkImplementation = None,
+        trace_axes: tuple = (),
+        store_on_device: bool = False,
+        flatten: bool = True,
+        data_keys: Optional[List[str]] = None,
+    ):
+        """
+        Constructor of the JAX NTK computation class with subsampling.
+
+        Parameters
+        ----------
+        apply_fn : Callable
+            The function that applies the neural network to an input.
+            This function should be implemented using JAX. It should take in a
+            dictionary of parameters (and possibly other arguments) and return the
+            output of the neural network.
+            For models taking in `batch_stats` the apply function should look like::
+                def apply_fn(params, x):
+                    return model.apply(
+                        params, x, train=False, mutable=['batch_stats']
+                    )[0]
+        ntk_size : int
+            Size of the NTK sub-samples.
+        seed : int (default = 0)
+            Random seed used for drawing the subsample indices.
+        batch_size : int
+            Size of the batch to use in the NTK calculation.
+            Note that this has to be compatible with the chosen `ntk_size`.
+        ntk_implementation : Union[None, NtkImplementation] (default = None)
+            Implementation of the NTK computation.
+            The implementation depends on the trace_axes and the model
+            architecture. The default automatically takes the trace_axes into
+            account: for trace_axes=() the default is NTK_VECTOR_PRODUCTS,
+            for all other cases, including trace_axes=(-1,), the default is
+            JACOBIAN_CONTRACTION. For more specific use cases, the user can
+            set the implementation manually.
+            Information about the implementations and their specific requirements
+            can be found in the neural_tangents documentation.
+        trace_axes : Union[int, Sequence[int]]
+            Tracing over axes of the NTK.
+            The default value is trace_axes=(), which computes the full NTK of
+            rank 4.
+            To reduce the NTK to a tensor of rank 2, set trace_axes=(-1,).
+        store_on_device : bool, default False
+            Whether to store the NTK on the device or not.
+            This should be set False for large NTKs that do not fit in GPU memory.
+        flatten : bool, default True
+            If True, the NTK shape is checked and flattened into a 2D matrix, if
+            required.
+        data_keys : List[str], default ["inputs", "targets"]
+            The keys used to define inputs and targets in the dataset.
+            These keys are used to extract values from the dataset dictionary in
+            the `compute_ntk` method.
+            Note that the first key has to refer to the input data and the second
+            key to the targets / labels of the dataset.
+        """
+        super().__init__(
+            apply_fn=apply_fn,
+            batch_size=batch_size,
+            ntk_implementation=ntk_implementation,
+            trace_axes=trace_axes,
+            store_on_device=store_on_device,
+            flatten=flatten,
+            data_keys=data_keys,
+        )
+        self.ntk_size = ntk_size
+        self.key = random.PRNGKey(seed)
+
+        self._sample_indices: List[np.ndarray] = []
+        self.n_parts = None
+
+    def _get_sample_indices(self, x: np.ndarray) -> List[np.ndarray]:
+        """
+        Split the data into `n_parts` parts of size `ntk_size`.
+
+        Parameters
+        ----------
+        x : np.ndarray
+            The input data.
+
+        Returns
+        -------
+        List[np.ndarray]
+            A list of indices for each part of the data. Each index array has
+            length `ntk_size`.
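+
+        Examples
+        --------
+        A sketch for 10 samples and `ntk_size=4`: the indices are randomly
+        permuted and split into `10 // 4 = 2` arrays, e.g.::
+
+            [np.array([7, 2, 9, 0]), np.array([5, 1, 8, 3])]
+
+        The two remaining samples are dropped.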
+ """ + data_len = x.shape[0] + self.n_parts = data_len // self.ntk_size + + key, self.key = random.split(self.key) + + indices = random.permutation(key, np.arange(data_len)) + + return [ + indices[i * self.ntk_size : (i + 1) * self.ntk_size] + for i in range(self.n_parts) + ] + + def _subsample_data(self, x: np.ndarray) -> np.ndarray: + """ + Subsample the data based on self._sample_indices. + + Parameters + ---------- + x : np.ndarray + The input data. + + Returns + ------- + np.ndarray + The subsampled data. + """ + return [np.take(x, indices, axis=0) for indices in self._sample_indices] + + def _compute_ntk( + self, params: dict, x_i: np.ndarray, x_j: Optional[np.ndarray] = None + ) -> np.ndarray: + """ + Compute the NTK for the neural network. + + Parameters + ---------- + params : dict + The parameters of the neural network. + x_i : np.ndarray + The input to the neural network. + x_j : np.ndarray + The input to the neural network. + + Returns + ------- + np.ndarray + The NTK matrix. + """ + ntk = self.empirical_ntk(x_i, x_j, params) + ntk = self._check_shape(ntk) + return ntk + + def compute_ntk( + self, params: dict, dataset_i: dict, dataset_j: Optional[dict] = None + ) -> List[np.ndarray]: + """ + Compute the Neural Tangent Kernel (NTK) for the neural network. + + Parameters + ---------- + params : dict + The parameters of the neural network. + dataset_i : dict + The input dataset for the NTK computation. + dataset_j : Optional[dict] + Optional input dataset for the NTK computation. + + Returns + ------- + List[np.ndarray] + The NTK matrix. + """ + x_i = dataset_i[self.data_keys[0]] + x_j = dataset_j[self.data_keys[0]] if dataset_j is not None else None + + self._sample_indices = self._get_sample_indices(x_i) + x_i = self._subsample_data(x_i) + + x_j = self._subsample_data(x_j) if x_j is not None else [None] * self.n_parts + + ntks = jmap(lambda x_i, x_j: self._compute_ntk(params, x_i, x_j), x_i, x_j) + + return ntks diff --git a/znnl/models/flax_model.py b/znnl/models/flax_model.py index 6aa3726..73248cc 100644 --- a/znnl/models/flax_model.py +++ b/znnl/models/flax_model.py @@ -26,12 +26,11 @@ """ import logging -from typing import Callable, List, Sequence, Union +from typing import Callable, List import jax import jax.numpy as np from flax import linen as nn -from neural_tangents import NtkImplementation from znnl.models.jax_model import JaxModel @@ -74,12 +73,8 @@ def __init__( self, optimizer: Callable, input_shape: tuple, - batch_size: int = 10, layer_stack: List[nn.Module] = None, flax_module: nn.Module = None, - trace_axes: Union[int, Sequence[int]] = (-1,), - ntk_implementation: Union[None, NtkImplementation] = None, - store_on_device: bool = True, seed: int = None, ): """ @@ -94,28 +89,8 @@ def __init__( cross-compatibility is not assured. input_shape : tuple Shape of the NN input. - batch_size : int - Size of batch to use in the NTk calculation. flax_module : nn.Module Flax module to use instead of building one from scratch here. - trace_axes : Union[int, Sequence[int]] - Tracing over axes of the NTK. - The default value is trace_axes(-1,), which reduces the NTK to a tensor - of rank 2. - For a full NTK set trace_axes=(). - ntk_implementation : Union[None, NtkImplementation] (default = None) - Implementation of the NTK computation. - The implementation depends on the trace_axes and the model - architecture. The default does automatically take into account the - trace_axes. 
For trace_axes=() the default is NTK_VECTOR_PRODUCTS, - for all other cases including trace_axes=(-1,) the default is - JACOBIAN_CONTRACTION. For more specific use cases, the user can - set the implementation manually. - Information about the implementation and specific requirements can be - found in the neural_tangents documentation. - store_on_device : bool, default True - Whether to store the NTK on the device or not. - This should be set False for large NTKs that do not fit in GPU memory. seed : int, default None Random seed for the RNG. Uses a random int if not specified. """ @@ -137,10 +112,6 @@ def __init__( optimizer=optimizer, input_shape=input_shape, seed=seed, - trace_axes=trace_axes, - ntk_implementation=ntk_implementation, - ntk_batch_size=batch_size, - store_on_device=store_on_device, ) def _init_params(self, kernel_init: Callable = None, bias_init: Callable = None): @@ -166,7 +137,7 @@ def _init_params(self, kernel_init: Callable = None, bias_init: Callable = None) return params - def _ntk_apply_fn(self, params, inputs: np.ndarray): + def ntk_apply_fn(self, params, inputs: np.ndarray): """ Return an NTK capable apply function. diff --git a/znnl/models/huggingface_flax_model.py b/znnl/models/huggingface_flax_model.py index 9882890..7c03957 100644 --- a/znnl/models/huggingface_flax_model.py +++ b/znnl/models/huggingface_flax_model.py @@ -27,17 +27,12 @@ """ import logging -from typing import Any, Callable, List, Sequence, Union +from typing import Any, Callable import jax.numpy as np -import optax -from flax import linen as nn -from flax.training.train_state import TrainState -from neural_tangents import NtkImplementation from transformers import FlaxPreTrainedModel from znnl.models.jax_model import JaxModel -from znnl.optimizers.trace_optimizer import TraceOptimizer logger = logging.getLogger(__name__) @@ -51,10 +46,6 @@ def __init__( self, pre_built_model: FlaxPreTrainedModel, optimizer: Callable, - batch_size: int = 10, - trace_axes: Union[int, Sequence[int]] = (-1,), - store_on_device: bool = True, - ntk_implementation: Union[None, NtkImplementation] = None, ): """ Constructor of a HF flax model. @@ -68,26 +59,6 @@ def __init__( cross-compatibility is not assured. input_shape : tuple Shape of the NN input. - batch_size : int - Size of batch to use in the NTk calculation. - trace_axes : Union[int, Sequence[int]] - Tracing over axes of the NTK. - The default value is trace_axes(-1,), which reduces the NTK to a tensor - of rank 2. - For a full NTK set trace_axes=(). - ntk_implementation : Union[None, NtkImplementation] (default = None) - Implementation of the NTK computation. - The implementation depends on the trace_axes and the model - architecture. The default does automatically take into account the - trace_axes. For trace_axes=() the default is NTK_VECTOR_PRODUCTS, - for all other cases including trace_axes=(-1,) the default is - JACOBIAN_CONTRACTION. For more specific use cases, the user can - set the implementation manually. - Information about the implementation and specific requirements can be - found in the neural_tangents documentation. - store_on_device : bool, default True - Whether to store the NTK on the device or not. - This should be set False for large NTKs that do not fit in GPU memory. 
""" logger.info( "Flax models have occasionally experienced memory allocation issues on " @@ -102,13 +73,9 @@ def __init__( super().__init__( pre_built_model=pre_built_model, optimizer=optimizer, - trace_axes=trace_axes, - ntk_batch_size=batch_size, - store_on_device=store_on_device, - ntk_implementation=ntk_implementation, ) - def _ntk_apply_fn(self, params: dict, inputs: np.ndarray): + def ntk_apply_fn(self, params: dict, inputs: np.ndarray): """ Return an NTK capable apply function. diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index eb1448b..4160ba3 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -25,12 +25,11 @@ ------- """ -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Union import jax import jax.numpy as np import jax.random -import neural_tangents as nt import optax from flax.training.train_state import TrainState from transformers import FlaxPreTrainedModel @@ -70,11 +69,7 @@ def __init__( optimizer: Union[Callable, TraceOptimizer], input_shape: Optional[tuple] = None, seed: Optional[int] = None, - ntk_batch_size: int = 10, - trace_axes: Union[int, Sequence[int]] = (), - store_on_device: bool = True, pre_built_model: Union[None, FlaxPreTrainedModel] = None, - ntk_implementation: Union[None, nt.NtkImplementation] = None, ): """ Construct a znrnd model. @@ -87,29 +82,9 @@ def __init__( Shape of the NN input. Required if no pre-built model is passed. seed : int, default None Random seed for the RNG. Uses a random int if not specified. - ntk_batch_size : Optional[int], default 10 - Batch size to use in the NTK computation. - trace_axes : Union[int, Sequence[int]] - Tracing over axes of the NTK. - The default value is trace_axes=(), providing the full NTK of rank 4. - For a traced NTK set trace_axes=(-1,), which reduces the NTK to a - tensor of rank 2. - store_on_device : bool, default True - Whether to store the NTK on the device or not. - This should be set False for large NTKs that do not fit in GPU memory. pre_built_model : Union[None, FlaxPreTrainedModel] (default = None) Pre-built model to use instead of building one from scratch here. So far, this is only implemented for Hugging Face flax models. - ntk_implementation : Union[None, nt.NtkImplementation] (default = None) - Implementation of the NTK computation. - The implementation depends on the trace_axes and the model - architecture. The default does automatically take into account the - trace_axes. For trace_axes=() the default is NTK_VECTOR_PRODUCTS, - for all other cases including trace_axes=(-1,) the default is - JACOBIAN_CONTRACTION. For more specific use cases, the user can - set the implementation manually. - Information about the implementation and specific requirements can be - found in the neural_tangents documentation. 
""" self.optimizer = optimizer @@ -131,21 +106,6 @@ def __init__( else: self.model_state = self._create_train_state(params=pre_built_model.params) - # Prepare NTK calculation - if not ntk_implementation: - if trace_axes == (): - ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS - else: - ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION - self.empirical_ntk = nt.batch( - nt.empirical_ntk_fn( - f=self._ntk_apply_fn, - trace_axes=trace_axes, - implementation=ntk_implementation, - ), - batch_size=ntk_batch_size, - store_on_device=store_on_device, - ) self.apply_jit = jax.jit(self.apply) def init_model( @@ -205,7 +165,7 @@ def _create_train_state(self, params: dict) -> TrainState: ) return train_state - def _ntk_apply_fn(self, params: dict, inputs: np.ndarray): + def ntk_apply_fn(self, params: dict, inputs: np.ndarray): """ Apply function used in the NTK computation. @@ -242,50 +202,6 @@ def train_apply_fn(self, params: dict, inputs: np.ndarray): """ raise NotImplementedError("Implemented in child class") - def compute_ntk( - self, - x_i: np.ndarray, - x_j: np.ndarray = None, - infinite: bool = False, - ): - """ - Compute the NTK matrix for the model. - - Parameters - ---------- - x_i : np.ndarray - Dataset for which to compute the NTK matrix. - x_j : np.ndarray (optional) - Dataset for which to compute the NTK matrix. - infinite : bool (default = False) - If true, compute the infinite width limit as well. - - Returns - ------- - NTK : dict - The NTK matrix for both the empirical and infinite width computation. - """ - if x_j is None: - x_j = x_i - empirical_ntk = self.empirical_ntk( - x_i, - x_j, - { - "params": self.model_state.params, - "batch_stats": self.model_state.batch_stats, - }, - ) - - if infinite: - try: - infinite_ntk = self.kernel_fn(x_i, x_j, "ntk") - except AttributeError: - raise NotImplementedError("Infinite NTK not available for this model.") - else: - infinite_ntk = None - - return {"empirical": empirical_ntk, "infinite": infinite_ntk} - def __call__(self, feature_vector: np.ndarray): """ Call the network. diff --git a/znnl/models/nt_model.py b/znnl/models/nt_model.py index 688b491..a51ce93 100644 --- a/znnl/models/nt_model.py +++ b/znnl/models/nt_model.py @@ -26,12 +26,10 @@ """ import logging -from typing import Any, Callable, Sequence, Union +from typing import Callable, Union import jax import jax.numpy as np -import neural_tangents as nt -from neural_tangents import NtkImplementation from neural_tangents.stax import serial from znnl.models.jax_model import JaxModel @@ -50,10 +48,6 @@ def __init__( optimizer: Union[Callable, TraceOptimizer], input_shape: tuple, nt_module: serial = None, - batch_size: int = 10, - trace_axes: Union[int, Sequence[int]] = (-1,), - ntk_implementation: Union[None, NtkImplementation] = None, - store_on_device: bool = True, seed: int = None, ): """ @@ -68,42 +62,17 @@ def __init__( Shape of the NN input. nt_module : serial NT model used. - batch_size : int, default 10 - Batch size to use in the NTK computation. - trace_axes : Union[int, Sequence[int]] - Tracing over axes of the NTK. - The default value is trace_axes(-1,), which reduces the NTK to a tensor - of rank 2. - For a full NTK set trace_axes=(). - ntk_implementation : Union[None, NtkImplementation] (default = None) - Implementation of the NTK computation. - The implementation depends on the trace_axes and the model - architecture. The default does automatically take into account the - trace_axes. 
For trace_axes=() the default is NTK_VECTOR_PRODUCTS, - for all other cases including trace_axes=(-1,) the default is - JACOBIAN_CONTRACTION. For more specific use cases, the user can - set the implementation manually. - Information about the implementation and specific requirements can be - found in the neural_tangents documentation. - store_on_device : bool, default True - Whether to store the NTK on the device or not. - This should be set False for large NTKs that do not fit in GPU memory. seed : int, default None Random seed for the RNG. Uses a random int if not specified. """ self.init_fn = nt_module[0] self.apply_fn = jax.jit(nt_module[1]) - self.kernel_fn = nt.batch(nt_module[2], batch_size=batch_size) # Save input parameters, call self.init_model super().__init__( optimizer=optimizer, input_shape=input_shape, seed=seed, - trace_axes=trace_axes, - ntk_batch_size=batch_size, - store_on_device=store_on_device, - ntk_implementation=ntk_implementation, ) def _init_params(self, kernel_init: Callable = None, bias_init: Callable = None): @@ -168,7 +137,7 @@ def train_apply_fn(self, params: dict, inputs: np.ndarray): """ return self.apply_fn(params["params"], inputs) - def _ntk_apply_fn(self, params: dict, inputs: np.ndarray): + def ntk_apply_fn(self, params: dict, inputs: np.ndarray): """ NTK Apply function for the neural_tangents module. diff --git a/znnl/optimizers/trace_optimizer.py b/znnl/optimizers/trace_optimizer.py index dbe5723..434995c 100644 --- a/znnl/optimizers/trace_optimizer.py +++ b/znnl/optimizers/trace_optimizer.py @@ -83,7 +83,8 @@ def apply_optimizer( # Check if the update should be performed. if epoch % self.rescale_interval == 0: # Compute the ntk trace. - ntk = ntk_fn(data_set)["empirical"] + ntk = ntk_fn({"params": model_state.params}, data_set) + ntk = np.array(ntk).mean(axis=0) trace = np.trace(ntk) # Create the new optimizer. diff --git a/znnl/training_recording/__init__.py b/znnl/training_recording/__init__.py index 0d62617..c31252f 100644 --- a/znnl/training_recording/__init__.py +++ b/znnl/training_recording/__init__.py @@ -26,6 +26,9 @@ """ from znnl.training_recording.data_storage import DataStorage -from znnl.training_recording.jax_recording import JaxRecorder +from znnl.training_recording.papyrus_jax_recording import JaxRecorder -__all__ = [JaxRecorder.__name__, DataStorage.__name__] +__all__ = [ + JaxRecorder.__name__, + DataStorage.__name__, +] diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py deleted file mode 100644 index 0efcff9..0000000 --- a/znnl/training_recording/jax_recording.py +++ /dev/null @@ -1,688 +0,0 @@ -""" -ZnNL: A Zincwarecode package. - -License -------- -This program and the accompanying materials are made available under the terms -of the Eclipse Public License v2.0 which accompanies this distribution, and is -available at https://www.eclipse.org/legal/epl-v20.html - -SPDX-License-Identifier: EPL-2.0 - -Copyright Contributors to the Zincwarecode Project. 
- -Contact Information -------------------- -email: zincwarecode@gmail.com -github: https://github.com/zincware -web: https://zincwarecode.com/ - -Citation --------- -If you use this module please cite us with: - -Summary -------- -""" - -import logging -from dataclasses import dataclass, make_dataclass -from os import path -from pathlib import Path -from typing import Optional - -import jax.numpy as np -import numpy as onp - -from znnl.accuracy_functions.accuracy_function import AccuracyFunction -from znnl.analysis.eigensystem import EigenSpaceAnalysis -from znnl.analysis.entropy import EntropyAnalysis -from znnl.analysis.loss_fn_derivative import LossDerivative -from znnl.loss_functions import SimpleLoss -from znnl.models.jax_model import JaxModel -from znnl.training_recording.data_storage import DataStorage -from znnl.utils.matrix_utils import ( - calculate_trace, - compute_magnitude_density, - flatten_rank_4_tensor, - normalize_gram_matrix, -) - -logger = logging.getLogger(__name__) - - -@dataclass -class JaxRecorder: - """ - Class for recording jax training. - - Each property is has an accompanying array which is set to None. - At instantiation, this can be populated correctly or resized in the event of - re-training. - Only loss is set default True. - - Attributes - ---------- - name : str - Name of the recorder. - storage_path : str - Where to store the data on disk. - chunk_size : int - Amount of data to keep in memory before it is dumped to a hdf5 database. - loss : bool (default=True) - If true, loss will be recorded. - accuracy : bool (default=False) - If true, accuracy will be recorded. - network_predictions : bool (default=False) - If true, network predictions will be recorded. - ntk : bool (default=False) - If true, the ntk will be recorded. Warning, large overhead. - covariance_ntk : bool (default = False) - If true, the covariance of the ntk will be recorded. - Warning, large overhead. - magnitude_ntk : bool (default = False) - If true, gradient magnitudes of the ntk will be recorded. - Warning, large overhead. - entropy : bool (default=False) - If true, entropy will be recorded. Warning, large overhead. - covariance_entropy : bool (default=False) - If true, the entropy of the covariance ntk will be recorded. - Warning, large overhead. - magnitude_variance : bool (default=False) - If true, the variance of the gradient magnitudes of the ntk will be - recorded. - magnitude_entropy : bool (default=False) - If true, the entropy of the gradient magnitudes of the ntk will be recorded. - Warning, large overhead. - eigenvalues : bool (default=False) - If true, eigenvalues will be recorded. Warning, large overhead. - loss_derivative : bool (default=False) - If true, the derivative of the loss function with respect to the network - output will be recorded. - update_rate : int (default=1) - How often the values are updated. - flatten_ntk : bool (default=False) - If true, an NTK of rank 4 will be flattened to a rank 2 tensor. - In case of an NTK of rank 2, this has no effect. - - Notes - ----- - Currently the options are hard-coded. In the future, we will work towards allowing - for arbitrary computations to be added, for example, two losses. 
- """ - - # Recorder Attributes - name: str = "my_recorder" - storage_path: str = "./" - chunk_size: int = 100 - flatten_ntk: bool = True - - # Model Loss - loss: bool = True - _loss_array: list = None - - # Model accuracy - accuracy: bool = False - _accuracy_array: list = None - - # Model predictions - network_predictions: bool = False - _network_predictions_array: list = None - - # NTK Matrix - ntk: bool = False - _ntk_array: list = None - - # Covariance NTK Matrix - covariance_ntk: bool = False - _covariance_ntk_array: list = None - - # Magnitude NTK array - magnitude_ntk: bool = False - _magnitude_ntk_array: list = None - - # Entropy of the model - entropy: bool = False - _entropy_array: list = None - - # Covariance Entropy of the model - covariance_entropy: bool = False - _covariance_entropy_array: list = None - - # Magnitude Variance of the model - magnitude_variance: bool = False - _magnitude_variance_array: list = None - - # Magnitude Entropy of the model - magnitude_entropy: bool = False - _magnitude_entropy_array: list = None - - # Model eigenvalues - eigenvalues: bool = False - _eigenvalues_array: list = None - - # Model trace - trace: bool = False - _trace_array: list = None - - # Loss derivative - loss_derivative: bool = False - _loss_derivative_array: list = None - - # Class helpers - update_rate: int = 1 - _loss_fn: SimpleLoss = None - _accuracy_fn: AccuracyFunction = None - _selected_properties: list = None - _model: JaxModel = None - _data_set: dict = None - _compute_ntk: bool = False # Helps to know if we can compute it once and share. - _compute_loss_derivative: bool = False - _loss_derivative_fn: LossDerivative = False - _index_count: int = 0 # Helps to avoid problems with non-1 update rates. - _data_storage: DataStorage = None # For writing to disk. - _ntk_rank: Optional[int] = None # Rank of the NTK matrix. - - def _read_selected_attributes(self): - """ - Helper function to read selected attributes. - """ - # populate the class attribute - self._selected_properties = [ - value - for value in list(vars(self)) - if value[0] != "_" and vars(self)[value] is True and value != "flatten_ntk" - ] - - def _build_or_resize_array(self, name: str, overwrite: bool): - """ - Build or resize an array. - - Parameters - ---------- - name : str - Name of array. Needed to check for resizing. - overwrite : bool - If True, arrays will not be resized but overwritten. - - Returns - ------- - A np zeros array or a resized array padded with zeros. - """ - # Check if array exists - data = getattr(self, name) - - # Create array if none or if overwrite is set true - if data is None or overwrite: - data = [] - - return data - - def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): - """ - Prepare the recorder for training. - - Parameters - ---------- - data_set : dict (default=None) - Data to record during training. - overwrite : bool (default=False) - If true and there is data already in the array, this will be removed and - a new array created. - - Returns - ------- - Populates the array attributes of the dataclass. - """ - # Create the data storage manager. - _storage_path = path.join(self.storage_path, self.name) - self._data_storage = DataStorage(Path(_storage_path)) - - if data_set: - # Update simple attributes - self._data_set = data_set - if self._data_set is None and data_set is None: - raise AttributeError( - "No data set given for the recording process." - "Instantiate the recorder with a data set." 
- ) - - # populate the class attribute - self._read_selected_attributes() - - # Instantiate arrays - all_attributes = self.__dict__ - for item in self._selected_properties: - if item == "ntk": - all_attributes[f"_{item}_array"] = self._build_or_resize_array( - f"_{item}_array", overwrite - ) - else: - all_attributes[f"_{item}_array"] = self._build_or_resize_array( - f"_{item}_array", overwrite - ) - - # If over-writing, reset the index count - if overwrite: - self._index_count = 0 - - # Check if we need an NTK computation and update the class accordingly - if any( - [ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ] - ): - self._compute_ntk = True - - if "loss_derivative" in self._selected_properties: - self._loss_derivative_fn = LossDerivative(self._loss_fn) - - def update_recorder(self, epoch: int, model: JaxModel): - """ - Update the values stored in the recorder. - - Parameters - ---------- - epoch : int - Current epoch of the model. - model : JaxModel - Model to use in the update. - - Returns - ------- - Updates the chosen class attributes depending on the user requirements. - """ - # Check if we need to record and if so, record - if epoch % self.update_rate == 0: - # Update here to expose to other methods. - self._model = model - - parsed_data = {"epoch": self._index_count} - - # Add epoch to the parsed data - # Compute representations here to avoid repeated computation. - predictions = self._model(self._data_set["inputs"]) - if type(predictions) is tuple: - predictions = predictions[0] - parsed_data["predictions"] = predictions - - # Compute ntk here to avoid repeated computation. - if self._compute_ntk: - try: - ntk = self._model.compute_ntk( - self._data_set["inputs"], infinite=False - )["empirical"] - self._ntk_rank = len(ntk.shape) - if self.flatten_ntk and self._ntk_rank == 4: - ntk = flatten_rank_4_tensor(ntk) - parsed_data["ntk"] = ntk - except NotImplementedError: - logger.info( - "NTK calculation is not yet available for this model. Removing " - "it from this recorder." - ) - self.ntk = False - self.covariance_ntk = False - self.magnitude_ntk = False - self.entropy = False - self.magnitude_entropy = False - self.magnitude_variance = False - self.covariance_entropy = False - self.eigenvalues = False - self._read_selected_attributes() - - for item in self._selected_properties: - call_fn = getattr(self, f"_update_{item}") # get the callable function - - # Try to add data and resize if necessary. - call_fn(parsed_data) # call the function and update the property - - self._index_count += 1 # Update the index count. - else: - pass - - # Dump records if the index hits the chunk size. - if self._index_count == self.chunk_size: - self.dump_records() - - def dump_records(self): - """ - Dump recorded properties to hdf5 database. - """ - export_data = self._export_in_memory_data() # collect the in-memory data. - self._data_storage.write_data(export_data) - self.instantiate_recorder(self._data_set, overwrite=True) # clear data. - - def visualize_recorder(self): - """ - Display recorded values as web app. 
- - Returns - ------- - - """ - raise NotImplementedError("Not yet available in ZnRND.") - - @property - def loss_fn(self): - """ - The loss function property of a recorder. - - Returns - ------- - The loss function used in the recorder. - """ - return self._loss_fn - - @loss_fn.setter - def loss_fn(self, loss_fn: SimpleLoss): - """ - Setting a loss function for a recorder. - - Parameters - ---------- - loss_fn : SimpleLoss - Loss function used for recording. - """ - self._loss_fn = loss_fn - - @property - def accuracy_fn(self): - """ - The accuracy function property of a recorder. - - Returns - ------- - The accuracy function used in the recorder. - """ - return self._accuracy_fn - - @accuracy_fn.setter - def accuracy_fn(self, accuracy_fn: AccuracyFunction): - """ - Setting an accuracy function for a recorder. - - Parameters - ---------- - accuracy_fn : AccuracyFunction - Accuracy function used for recording. - """ - self._accuracy_fn = accuracy_fn - - def _update_loss(self, parsed_data: dict): - """ - Update the loss array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - self._loss_array.append( - self._loss_fn(parsed_data["predictions"], self._data_set["targets"]) - ) - - def _update_accuracy(self, parsed_data: dict): - """ - Update the accuracy array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - try: - self._accuracy_array.append( - self._accuracy_fn(parsed_data["predictions"], self._data_set["targets"]) - ) - except TypeError: - logger.info( - "There is no accuracy function defined in the training procedure, " - "switching this recording option off." - ) - self.accuracy = False - self._read_selected_attributes() - - def _update_network_predictions(self, parsed_data: dict): - """ - Update the network predictions array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - self._network_predictions_array.append(parsed_data["predictions"]) - - def _update_ntk(self, parsed_data: dict): - """ - Update the ntk array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - self._ntk_array.append(parsed_data["ntk"]) - - def _update_covariance_ntk(self, parsed_data: dict): - """ - Update the covariance ntk array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - cov_ntk = normalize_gram_matrix(parsed_data["ntk"]) - self._covariance_ntk_array.append(cov_ntk) - - def _update_magnitude_ntk(self, parsed_data: dict): - """ - Update the magnitude ntk array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"]) - self._magnitude_ntk_array.append(magnitude_dist) - - def _update_entropy(self, parsed_data: dict): - """ - Update the entropy array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - calculator = EntropyAnalysis(matrix=parsed_data["ntk"]) - entropy = calculator.compute_von_neumann_entropy( - effective=False, normalize_eig=True - ) - self._entropy_array.append(entropy) - - def _update_covariance_entropy(self, parsed_data: dict): - """ - Update the entropy of the covariance NTK. 
- - The covariance ntk is defined as the of cosine similarities. For this each - entry of the NTK is re-scaled by the gradient amplitudes. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - cov_ntk = normalize_gram_matrix(parsed_data["ntk"]) - calculator = EntropyAnalysis(matrix=cov_ntk) - entropy = calculator.compute_von_neumann_entropy( - effective=False, normalize_eig=True - ) - self._covariance_entropy_array.append(entropy) - - def _update_magnitude_entropy(self, parsed_data: dict): - """ - Update magnitude entropy of the NTK. - - The magnitude entropy is defined as the entropy of the normalized gradient - magnitudes. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"]) - entropy = EntropyAnalysis.compute_shannon_entropy(magnitude_dist) - self._magnitude_entropy_array.append(entropy) - - def _update_magnitude_variance(self, parsed_data: dict): - """ - Update the magnitude variance of the NTK. - - The magnitude variance is defined as the variance of the normalized gradient - magnitudes. - As the normalization to obtain the magnitude distribution is done by dividing - by the sum of the magnitudes, the variance is calculated as: - - magnitude_variance = var(magnitudes * magnitudes.shape[0]) - - This ensures that the variance is not dependent on the number entries in the - magnitude distribution. - It is equivalent to the following: - - ntk_diag = sqrt( diag(ntk) ) - magnitude_variance = var( diag / mean(ntk_diag) ) - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"]) - magvar = np.var(magnitude_dist * magnitude_dist.shape[0]) - self._magnitude_variance_array.append(magvar) - - def _update_eigenvalues(self, parsed_data: dict): - """ - Update the eigenvalue array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - calculator = EigenSpaceAnalysis(matrix=parsed_data["ntk"]) - eigenvalues = calculator.compute_eigenvalues(normalize=False) - self._eigenvalues_array.append(eigenvalues) - - def _update_trace(self, parsed_data: dict): - """ - Update the trace of the NTK. - - The trace of the NTK is computed as the mean of the diagonal elements of the - NTK. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - trace = calculate_trace(parsed_data["ntk"], normalize=True) - self._trace_array.append(trace) - - def _update_loss_derivative(self, parsed_data): - """ - Update the loss derivative array. - - The loss derivative records the derivative of the loss function with respect to - the network output, returning a vector of the same shape as the network output. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - vector_loss_derivative = self._loss_derivative_fn.calculate( - parsed_data["predictions"], self._data_set["targets"] - ) - self._loss_derivative_array.append(vector_loss_derivative) - - def gather_recording(self, selected_properties: list = None) -> dataclass: - """ - Export a dataclass of used properties. - - Parameters - ---------- - selected_properties : list default = None - List of parameters to export. 
If None, all available are exported. - - Returns - ------- - dataset : object - A dataclass of only the data recorder during the training. - """ - if selected_properties is None: - selected_properties = self._selected_properties - else: - # Check if we can collect all the data. - comparison = [i in self._selected_properties for i in selected_properties] - # Throw away data that was not saved in the first place. - if not all(comparison): - logger.info( - "You have asked for properties that were not recorded. Removing" - " the impossible elements." - ) - selected_properties = onp.array(selected_properties)[ - onp.array(comparison).astype(int) == 1 - ] - - DataSet = make_dataclass( - "DataSet", [(item, onp.ndarray) for item in selected_properties] - ) - selected_data = { - item: onp.array(vars(self)[f"_{item}_array"]) - for item in selected_properties - } - - # Try to load some data from the hdf5 database. - try: - db_data = self._data_storage.fetch_data(selected_properties) - # Add db data to the selected data dict. - for item, data in selected_data.items(): - selected_data[item] = onp.concatenate((db_data[item], data), axis=0) - - except FileNotFoundError: # There is no database. - pass - - return DataSet(**selected_data) - - def _export_in_memory_data(self) -> dataclass: - """ - Export a dataclass of used properties. - - Returns - ------- - dataset : object - A dataclass of only the data recorder during the training. - """ - DataSet = make_dataclass( - "DataSet", [(item, onp.ndarray) for item in self._selected_properties] - ) - selected_data = { - item: onp.array(vars(self)[f"_{item}_array"]) - for item in self._selected_properties - } - - return DataSet(**selected_data) diff --git a/znnl/training_recording/papyrus_jax_recording.py b/znnl/training_recording/papyrus_jax_recording.py new file mode 100644 index 0000000..1952d99 --- /dev/null +++ b/znnl/training_recording/papyrus_jax_recording.py @@ -0,0 +1,255 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +from typing import List + +from papyrus.measurements import BaseMeasurement +from papyrus.recorders import BaseRecorder + +from znnl.analysis import JAXNTKComputation +from znnl.models import JaxModel + + +class JaxRecorder(BaseRecorder): + """ + Recorder for a fixed dataset. + + Attributes + ---------- + name : str + The name of the recorder, defining the name of the file the data will be + stored in. + storage_path : str + The path to the storage location of the recorder. + measurements : List[BaseMeasurement] + The measurements that the recorder will apply. + chunk_size : int + The size of the chunks in which the data will be stored. + overwrite : bool (default=False) + Whether to overwrite the existing data in the database. + update_rate : int (default=1) + The rate at which the recorder will update the neural state. + neural_state_keys : List[str] + The keys of the neural state that the recorder takes as input. 
+ A neural state is a dictionary of numpy arrays that represent the state of
+ a neural network.
+ neural_state : Dict[str, onp.ndarray]
+ The neural state dictionary containing a state representation of a neural
+ network.
+ _data_set : Dict[str, onp.ndarray]
+ The dataset that will be used to create the neural state.
+ It needs to be a dictionary of numpy arrays with the following keys:
+ - "inputs": The inputs of the dataset.
+ - "targets": The targets of the dataset.
+ _model : JaxModel
+ A neural network module. For more information see the JaxModel class.
+ _ntk_computation : JAXNTKComputation
+ An NTK computation module. For more information see the JAXNTKComputation
+ class.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ storage_path: str,
+ measurements: List[BaseMeasurement],
+ chunk_size: int = int(1e5),
+ overwrite: bool = False,
+ update_rate: int = 1,
+ ):
+ """
+ Constructor method of the JaxRecorder class.
+
+ Parameters
+ ----------
+ name : str
+ The name of the recorder, defining the name of the file the data will be
+ stored in.
+ storage_path : str
+ The path to the storage location of the recorder.
+ measurements : List[BaseMeasurement]
+ The measurements that the recorder will apply.
+ chunk_size : int (default=1e5)
+ The size of the chunks in which the data will be stored.
+ overwrite : bool (default=False)
+ Whether to overwrite the existing data in the database.
+ update_rate : int (default=1)
+ The rate at which the recorder will update the neural state.
+ """
+ super().__init__(name, storage_path, measurements, chunk_size, overwrite)
+ self.update_rate = update_rate
+
+ self.neural_state = {}
+
+ self._data_set = None
+ self._model = None
+ self._ntk_computation = None
+
+ def instantiate_recorder(
+ self,
+ model: JaxModel = None,
+ data_set: dict = None,
+ ntk_computation: JAXNTKComputation = None,
+ ):
+ """
+ Prepare the recorder for training.
+
+ Instantiate the recorder with the required modules and data set.
+
+ The instantiation method performs the following checks:
+ - Check if the neural network module is required and provided.
+ - Check if the NTK computation module is required and provided.
+ - Check if the data set is provided.
+
+ Parameters
+ ----------
+ model : JaxModel (default=None)
+ The neural network module to record during training.
+ data_set : dict (default=None)
+ Data to record during training. The first key needs to be the input data
+ and the second key the target data.
+ ntk_computation : JAXNTKComputation (default=None)
+ Computation of the NTK matrix.
+ If the NTK is to be computed, this module is required.
+
+ Returns
+ -------
+ Updates the module and data attributes of the recorder.
+ """
+
+ # Check if the neural network module is required and provided
+ if "predictions" in self.neural_state_keys:
+ if model is None:
+ raise AttributeError(
+ "The neural network module is required for the recording process. "
+ "Instantiate the recorder with a JaxModel."
+ )
+
+ # Check if the NTK computation module is required and provided
+ if "ntk" in self.neural_state_keys:
+ if ntk_computation is None:
+ raise AttributeError(
+ "The NTK computation module is required for the recording process. "
+ "Instantiate the recorder with a JAXNTKComputation module."
+ )
+
+ # Check if the data set is provided
+ if data_set:
+ # Update simple attributes
+ self._data_set = data_set
+ if self._data_set is None and data_set is None:
+ raise AttributeError(
+ "No data set given for the recording process. "
+ "Instantiate the recorder with a data set."
+ )
+
+ self._model = model
+ self._ntk_computation = ntk_computation
+
+ def _check_keys(self):
+ """
+ Check if the provided keys match the neural state keys.
+
+ This method checks whether the computed neural state contains all keys
+ required by the selected measurements, in other words, whether the
+ incoming data is complete. It operates on self.neural_state and
+ therefore takes no arguments.
+ """
+ if any(key not in self.neural_state for key in self.neural_state_keys):
+ raise KeyError(
+ "The attributes that are computed do not match the required attributes. "
+ "The required attributes are: "
+ f"{self.neural_state_keys}. "
+ "The provided attributes are: "
+ f"{list(self.neural_state.keys())}."
+ )
+
+ def _compute_neural_state(self, model: JaxModel):
+ """
+ Compute the neural state.
+
+ Parameters
+ ----------
+ model : JaxModel
+ The neural network module containing the parameters to use for
+ recording.
+
+ Returns
+ -------
+ Updates the neural state dictionary.
+ """
+ if self._ntk_computation:
+ ntk = self._ntk_computation.compute_ntk(
+ params={
+ "params": model.model_state.params,
+ "batch_stats": model.model_state.batch_stats,
+ },
+ dataset_i=self._data_set,
+ )
+ self.neural_state["ntk"] = ntk
+ if self._model:
+ # By convention, the first key of the data set holds the inputs.
+ predictions = model(self._data_set[list(self._data_set.keys())[0]])
+ self.neural_state["predictions"] = [predictions]
+
+ def record(self, epoch: int, model: JaxModel, **kwargs):
+ """
+ Perform the recording of a neural state.
+
+ Recording is done by measuring and storing the measurements to a database.
+
+ Parameters
+ ----------
+ epoch : int
+ The epoch of the training process.
+ model : JaxModel
+ The neural network module containing the parameters to use for
+ recording.
+ kwargs : Any
+ Additional keyword arguments that are directly added to the neural
+ state.
+
+ Returns
+ -------
+ Stores the measurements of the neural state to the database.
+ """ + if epoch % self.update_rate == 0: + # Compute the neural state + self._compute_neural_state(model) + # Add all other kwargs to the neural state dictionary + self.neural_state.update(kwargs) + for key, val in self._data_set.items(): + self.neural_state[key] = [val] + # Check if incoming data is complete + self._check_keys() + # Perform measurements + self._measure(**self.neural_state) + # Store the measurements + self.store(ignore_chunk_size=False) diff --git a/znnl/training_strategies/loss_aware_reservoir.py b/znnl/training_strategies/loss_aware_reservoir.py index e679906..287f12d 100644 --- a/znnl/training_strategies/loss_aware_reservoir.py +++ b/znnl/training_strategies/loss_aware_reservoir.py @@ -35,6 +35,7 @@ from tqdm import trange from znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.distance_metrics import DistanceMetric from znnl.models.jax_model import JaxModel from znnl.optimizers.trace_optimizer import TraceOptimizer @@ -422,6 +423,11 @@ def train_model( state = self.model.model_state self.train_data_size = len(train_ds["targets"]) + if isinstance(self.model.optimizer, TraceOptimizer): + ntk_computation = JAXNTKComputation( + self.model.ntk_apply_fn, trace_axes=(-1,), batch_size=batch_size + ) + loading_bar = trange( 1, epochs + 1, ncols=100, unit="batch", disable=self.disable_loading_bar ) @@ -432,7 +438,7 @@ def train_model( # Update the recorder properties if self.recorders is not None: for item in self.recorders: - item.update_recorder(epoch=i, model=self.model) + item.record(epoch=i, model=self.model) loading_bar.set_description(f"Epoch: {i}") @@ -449,7 +455,7 @@ def train_model( state = self.model.optimizer.apply_optimizer( model_state=state, data_set=train_ds["inputs"][full_dataset_idx], - ntk_fn=self.model.compute_ntk, + ntk_fn=ntk_computation.compute_ntk, epoch=i, ) diff --git a/znnl/training_strategies/partitioned_training.py b/znnl/training_strategies/partitioned_training.py index f2d48af..0b91180 100644 --- a/znnl/training_strategies/partitioned_training.py +++ b/znnl/training_strategies/partitioned_training.py @@ -32,6 +32,7 @@ from tqdm import trange from znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.models.jax_model import JaxModel from znnl.optimizers.trace_optimizer import TraceOptimizer from znnl.training_recording import JaxRecorder @@ -262,6 +263,11 @@ def train_model( """ state = self.model.model_state + if isinstance(self.model.optimizer, TraceOptimizer): + ntk_computation = JAXNTKComputation( + self.model.ntk_apply_fn, trace_axes=(-1,), batch_size=batch_size + ) + loading_bar = trange( 1, np.sum(epochs) + 1, @@ -282,7 +288,7 @@ def train_model( # Update the recorder properties if self.recorders is not None: for item in self.recorders: - item.update_recorder(epoch=i, model=self.model) + item.record(epoch=i, model=self.model) loading_bar.set_description(f"Phase: {training_phase+1}, Epoch: {i}") @@ -297,7 +303,7 @@ def train_model( state = self.model.optimizer.apply_optimizer( model_state=state, data_set=train_data["inputs"], - ntk_fn=self.model.compute_ntk, + ntk_fn=ntk_computation.compute_ntk, epoch=i, ) diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index c0675e8..a53fc08 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -36,6 +36,7 @@ from tqdm import trange from 
znnl.accuracy_functions.accuracy_function import AccuracyFunction +from znnl.analysis.jax_ntk import JAXNTKComputation from znnl.models.jax_model import JaxModel from znnl.optimizers.trace_optimizer import TraceOptimizer from znnl.training_recording import JaxRecorder @@ -109,13 +110,6 @@ def __init__( self.review_metric = None - # Add the loss and accuracy function to the recorders and re-instantiate them - if self.recorders is not None: - for item in self.recorders: - item.loss_fn = loss_fn - item.accuracy_fn = accuracy_fn - item.instantiate_recorder() - # Initialize the train step self._train_step = None self._init_train_step() @@ -368,6 +362,11 @@ def train_model( """ state = self.model.model_state + if isinstance(self.model.optimizer, TraceOptimizer): + ntk_computation = JAXNTKComputation( + self.model.ntk_apply_fn, trace_axes=(-1,), batch_size=batch_size + ) + loading_bar = trange( 0, epochs, ncols=100, unit="batch", disable=self.disable_loading_bar ) @@ -378,7 +377,7 @@ def train_model( # Update the recorder properties if self.recorders is not None: for item in self.recorders: - item.update_recorder(epoch=i, model=self.model) + item.record(epoch=i, model=self.model) loading_bar.set_description(f"Epoch: {i}") @@ -386,7 +385,7 @@ def train_model( state = self.model.optimizer.apply_optimizer( model_state=state, data_set=train_ds["inputs"], - ntk_fn=self.model.compute_ntk, + ntk_fn=ntk_computation.compute_ntk, epoch=i, )
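
A minimal usage sketch of the recording flow introduced by this patch, for orientation only: the hunks above replace the old `update_recorder` call with `record` and move all NTK evaluation out of the model into `JAXNTKComputation`. In the sketch below, `model` (a `JaxModel`) and `train_ds` (a dict with "inputs" and "targets" arrays) are assumed to exist elsewhere, and the `measurements` list must be filled with concrete papyrus `BaseMeasurement` instances, none of which are named in this patch.

    from znnl.analysis import JAXNTKComputation
    from znnl.training_recording import JaxRecorder

    # NTK evaluation is now a standalone module built from the model's
    # NTK apply function; the subsampling and class-wise variants added in
    # this patch share the same interface.
    ntk_computation = JAXNTKComputation(model.ntk_apply_fn, trace_axes=(-1,))

    recorder = JaxRecorder(
        name="train_recorder",
        storage_path=".",
        measurements=[],  # fill with papyrus BaseMeasurement instances
        update_rate=1,  # record the neural state every epoch
    )

    # instantiate_recorder verifies that every module required by the
    # selected measurements is present: a JaxModel for "predictions",
    # a JAXNTKComputation for "ntk", and always a data set.
    recorder.instantiate_recorder(
        model=model,
        data_set=train_ds,
        ntk_computation=ntk_computation,
    )

    # Training strategies call record once per epoch; measurements are
    # computed from the neural state and stored to the papyrus database.
    recorder.record(epoch=0, model=model)

The same module replaces `model.compute_ntk` for the trace optimizer: as the training-strategy hunks show, each strategy builds a `JAXNTKComputation` with `trace_axes=(-1,)` and hands its `compute_ntk` method to the optimizer as `ntk_fn`, so NTK batching is configured in one place.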