class TestFisherTrace:
    """
    Class for testing the implementation of the fisher trace calculation.
    """

    def test_fisher_trace_computation(self):
        """
        Test the fisher trace computation against a hand-calculated example.

        The off-diagonal NTK blocks are filled with random numbers, the
        diagonal blocks ntk[0, 0] and ntk[1, 1] are fixed by hand.

        Hand calculation, with l_i the loss derivative of data point i:

            l_0^T ntk[0, 0] l_0 = 5 * 22 + 4 * 58 + 3 * 94 = 624
            l_1^T ntk[1, 1] l_1 = 2 * 5 + 1 * 4 + 0 * 8 = 14
            trace = (624 + 14) / 2 = 319

        Returns
        -------
        Asserts the calculated fisher trace for the manually defined inputs
        is what it should be.
        """
        ntk = np.array(
            [
                [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)],
                [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]],
            ]
        )
        loss_derivative = np.array([[5, 4, 3], [2, 1, 0]])

        trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk)

        # Exact equality on a floating point result is fragile, compare with a
        # tolerance instead.
        np.testing.assert_allclose(trace, 638 / 2, rtol=1e-5)
def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float:
    """
    Compute the Fisher matrix trace from the NTK.

    The value returned is

        1 / N * sum_n sum_{i, j} dl[n, i] * dl[n, j] * ntk[n, n, i, j]

    where dl is the loss derivative and N = loss_derivative.shape[0]. Only the
    blocks ntk[n, n, :, :] are read.

    Parameters
    ----------
    loss_derivative : np.ndarray (n_data_points, network_output)
            Loss derivative to use in the computation.
    ntk : np.ndarray (n_data_points, n_data_points, network_output, network_output)
            NTK of the network in one state.

    Returns
    -------
    fisher_trace : float
            Trace of the Fisher matrix corresponding to the NTK.

    Raises
    ------
    TypeError
            If the ntk is not of rank 4.
    """
    # Explicit check instead of an assert: assert statements are removed when
    # Python runs with -O, which would silently disable the validation.
    if len(ntk.shape) != 4:
        raise TypeError(
            "The ntk needs to be rank 4 for the fisher trace calculation. "
            "Maybe you have set the model to trace over the output dimensions? "
            "Try adding trace_axes=() to the models parameters."
        )

    def _inner_fn(a, b, c):
        """
        Function to be mapped over.
        """
        return a * b * c

    # a fixed, b and c mapped over the second output index j.
    map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0))
    # a and c mapped over the first output index i, b passed through whole.
    map_2 = jax.vmap(map_1, in_axes=(0, None, 0))
    # Everything mapped over the data points n.
    map_3 = jax.vmap(map_2, in_axes=(0, 0, 0))

    dataset_size = loss_derivative.shape[0]
    indices = np.arange(dataset_size)
    # Paired index arrays select the blocks ntk[n, n, :, :].
    diagonal_blocks = ntk[indices, indices, :, :]

    fisher_trace = np.sum(map_3(loss_derivative, loss_derivative, diagonal_blocks))

    return fisher_trace / dataset_size
update_rate : int (default=1) How often the values are updated. @@ -148,6 +149,10 @@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None + # Fisher trace + fisher_trace: bool = False + _fisher_trace_array: list = None + # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -244,7 +249,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if overwrite: self._index_count = 0 - # Check if we need an NTK computation and update the class accordingly + # Check if we need an NTK computation, update the class accordingly. if any( [ "ntk" in self._selected_properties, @@ -255,11 +260,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): "covariance_entropy" in self._selected_properties, "eigenvalues" in self._selected_properties, "trace" in self._selected_properties, + "fisher_trace" in self._selected_properties, ] ): self._compute_ntk = True - if "loss_derivative" in self._selected_properties: + # Check if we need a loss derivative computation, update the class accordingly + if any( + [ + "fisher_trace" in self._selected_properties, + "loss_derivative" in self._selected_properties, + ] + ): + self._compute_loss_derivative = True self._loss_derivative_fn = LossDerivative(self._loss_fn) def update_recorder(self, epoch: int, model: JaxModel): @@ -291,7 +304,7 @@ def update_recorder(self, epoch: int, model: JaxModel): predictions = predictions[0] parsed_data["predictions"] = predictions - # Compute ntk here to avoid repeated computation. + # Compute ntk and loss derivative here to avoid repeated computation. 
def _update_fisher_trace(self, parsed_data):
    """
    Append the current fisher trace to the fisher trace array.

    Parameters
    ----------
    parsed_data : dict
            Data computed before the update to prevent repeated calculations.
            The "loss_derivative" and "ntk" entries are read from it.
    """
    self._fisher_trace_array.append(
        compute_fisher_trace(
            loss_derivative=parsed_data["loss_derivative"],
            ntk=parsed_data["ntk"],
        )
    )