Record Fisher trace #92

Open
wants to merge 16 commits into base: main
59 changes: 59 additions & 0 deletions CI/unit_tests/observables/test_fisher_trace_calculation.py
@@ -0,0 +1,59 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
This module tests the implementation of the Fisher trace computation.
"""

import numpy as np

from znnl.observables.fisher_trace import compute_fisher_trace


class TestFisherTrace:
    """
    Class for testing the implementation of the Fisher trace calculation.
    """

    def test_fisher_trace_computation(self):
        """
        Test that the Fisher trace computation returns the correct value for an
        example that was calculated by hand beforehand.

        Returns
        -------
        Asserts that the calculated Fisher trace for the manually defined inputs
        matches the hand-computed value.
        """

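        # Only the diagonal (i, i) blocks of the NTK enter the Fisher trace, so
        # the off-diagonal blocks can be filled with random values.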
        ntk = np.array(
            [
                [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)],
                [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]],
            ]
        )
        loss_derivative = np.array([[5, 4, 3], [2, 1, 0]])

        trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk)
        assert trace == 638 / 2
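
For reference, the hand-computed expectation of 638 / 2 can be reproduced with a plain NumPy contraction over the diagonal NTK blocks. This is an illustrative sketch, independent of the module under test and not part of the PR:

import numpy as np

# Diagonal (i, i) NTK blocks and loss derivatives from the test above.
ntk_diag = np.array(
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 1, 3], [1, 2, 3], [3, 2, 1]]]
)
loss_derivative = np.array([[5, 4, 3], [2, 1, 0]])

# Tr(F) = (1/N) * sum_i dL_i^T K_ii dL_i, with N = 2 data points here.
trace = np.einsum("ik,ikl,il->", loss_derivative, ntk_diag, loss_derivative) / 2
print(trace)  # 319.0 == 638 / 2
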
30 changes: 30 additions & 0 deletions znnl/observables/__init__.py
@@ -0,0 +1,30 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the observables.
"""
from znnl.observables.fisher_trace import compute_fisher_trace

__all__ = [compute_fisher_trace.__name__]
73 changes: 73 additions & 0 deletions znnl/observables/fisher_trace.py
@@ -0,0 +1,73 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the computation of the Fisher trace.
"""
import jax
import jax.numpy as np


def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float:
"""
Compute the Fisher matrix trace from the NTK.

Parameters
----------
loss_derivative : np.ndarray (n_data_points, network_output)
Loss derivative to use in the computation.
ntk : np.ndarray (n_data_points, n_data_points, network_output, network_output)
NTK of the network in one state.

Returns
-------
fisher_trace : float
Trace of the Fisher matrix corresponding to the NTK.
"""
    if ntk.ndim != 4:
        raise TypeError(
            "The NTK needs to be rank 4 for the Fisher trace calculation. "
            "Maybe you have set the model to trace over the output dimensions? "
            "Try adding trace_axes=() to the model's parameters."
        )

    def _inner_fn(a, b, c):
        """
        Function to be mapped over.
        """
        return a * b * c

    map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0))
    map_2 = jax.vmap(map_1, in_axes=(0, None, 0))
    map_3 = jax.vmap(map_2, in_axes=(0, 0, 0))
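    # Applied below to the diagonal NTK blocks, the composed map produces
    # result[i, k, l] = loss_derivative[i, k] * loss_derivative[i, l] * ntk[i, i, k, l].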

    dataset_size = loss_derivative.shape[0]
    indices = np.arange(dataset_size)
    fisher_trace = np.sum(
        map_3(loss_derivative, loss_derivative, ntk[indices, indices, :, :])
    )

    return fisher_trace / dataset_size
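
For intuition, the nested vmap construction above is equivalent to a single einsum contraction over the diagonal NTK blocks. The helper below is a hypothetical reference formulation for comparison only; it is not part of the module:

import jax.numpy as np


def compute_fisher_trace_einsum(loss_derivative: np.ndarray, ntk: np.ndarray) -> float:
    """Hypothetical einsum reference for the Fisher trace, for comparison only."""
    dataset_size = loss_derivative.shape[0]
    indices = np.arange(dataset_size)
    diagonal_blocks = ntk[indices, indices, :, :]  # shape (n_data, out, out)
    # Tr(F) = (1/N) * sum_{i, k, l} dL[i, k] * NTK[i, i, k, l] * dL[i, l]
    return np.einsum(
        "ik,ikl,il->", loss_derivative, diagonal_blocks, loss_derivative
    ) / dataset_size
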
59 changes: 44 additions & 15 deletions znnl/training_recording/jax_recording.py
@@ -37,12 +37,9 @@
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.loss_functions import SimpleLoss
from znnl.models.jax_model import JaxModel
from znnl.observables.fisher_trace import compute_fisher_trace
from znnl.training_recording.data_storage import DataStorage
from znnl.utils.matrix_utils import (
calculate_l_pq_norm,
compute_magnitude_density,
normalize_gram_matrix,
)
from znnl.utils.matrix_utils import compute_magnitude_density, normalize_gram_matrix

logger = logging.getLogger(__name__)

@@ -90,6 +87,10 @@ class JaxRecorder:
loss_derivative : bool (default=False)
If true, the derivative of the loss function with respect to the network
output will be recorded.
fisher_trace : bool (default=False)
If true, the trace of the Fisher matrix will be recorded. Requires the NTK
and the loss derivative to be computed.
Warning: this computation adds a large overhead.
update_rate : int (default=1)
How often the values are updated.

@@ -148,6 +149,10 @@ class JaxRecorder:
loss_derivative: bool = False
_loss_derivative_array: list = None

# Fisher trace
fisher_trace: bool = False
_fisher_trace_array: list = None

# Class helpers
update_rate: int = 1
_loss_fn: SimpleLoss = None
@@ -244,7 +249,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
if overwrite:
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
# Check if we need an NTK computation, update the class accordingly.
if any(
[
"ntk" in self._selected_properties,
@@ -255,11 +260,19 @@
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
"fisher_trace" in self._selected_properties,
]
):
self._compute_ntk = True

if "loss_derivative" in self._selected_properties:
# Check if we need a loss derivative computation, update the class accordingly.
if any(
[
"fisher_trace" in self._selected_properties,
"loss_derivative" in self._selected_properties,
]
):
self._compute_loss_derivative = True
self._loss_derivative_fn = LossDerivative(self._loss_fn)

def update_recorder(self, epoch: int, model: JaxModel):
@@ -291,7 +304,7 @@ def update_recorder(self, epoch: int, model: JaxModel):
predictions = predictions[0]
parsed_data["predictions"] = predictions

# Compute ntk here to avoid repeated computation.
# Compute ntk and loss derivative here to avoid repeated computation.
if self._compute_ntk:
try:
ntk = self._model.compute_ntk(
@@ -311,6 +324,11 @@
self.covariance_entropy = False
self.eigenvalues = False
self._read_selected_attributes()
if self._compute_loss_derivative:
vector_loss_derivative = self._loss_derivative_fn.calculate(
parsed_data["predictions"], self._data_set["targets"]
)
parsed_data["loss_derivative"] = vector_loss_derivative

for item in self._selected_properties:
call_fn = getattr(self, f"_update_{item}") # get the callable function
@@ -342,7 +360,7 @@ def visualize_recorder(self):
-------

"""
raise NotImplementedError("Not yet available in ZnRND.")
raise NotImplementedError("Not yet available in ZnNL.")

@property
def loss_fn(self):
@@ -538,18 +556,28 @@ def _update_loss_derivative(self, parsed_data):
"""
Update the loss derivative array.

The loss derivative is normalized by the L_pq matrix norm.
Parameters
----------
parsed_data : dict
Data computed before the update to prevent repeated calculations.
"""
self._loss_derivative_array.append(parsed_data["loss_derivative"])

def _update_fisher_trace(self, parsed_data):
"""
Update the Fisher trace array.

Parameters
----------
parsed_data : dict
Data computed before the update to prevent repeated calculations.
"""
vector_loss_derivative = self._loss_derivative_fn.calculate(
parsed_data["predictions"], self._data_set["targets"]
)
loss_derivative = calculate_l_pq_norm(vector_loss_derivative)
self._loss_derivative_array.append(loss_derivative)
loss_derivative = parsed_data["loss_derivative"]
ntk = parsed_data["ntk"]

fisher_trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk)

self._fisher_trace_array.append(fisher_trace)

def gather_recording(self, selected_properties: list = None) -> dataclass:
"""
@@ -593,6 +621,7 @@ def gather_recording(self, selected_properties: list = None) -> dataclass:
db_data = self._data_storage.fetch_data(selected_properties)
# Add db data to the selected data dict.
for item, data in selected_data.items():
selected_data[item] = onp.concatenate((db_data[item], data), axis=0)

except FileNotFoundError: # There is no database.
Expand Down