Skip to content

Commit

Permalink
Fix bugs in jax combinations
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 31, 2024
1 parent 6e0fad9 commit c10fcdb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions znnl/analysis/jax_ntk_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray:
ntk = self._check_shape(ntk)
return ntk

def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:
def compute_ntk(self, params: dict, dataset_i: dict) -> List[np.ndarray]:
"""
Compute the Neural Tangent Kernel (NTK) for the neural network.
Expand All @@ -223,9 +223,9 @@ def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:
The NTK matrix.
"""

self._sample_indices = self._get_label_indices(dataset)
self._sample_indices = self._get_label_indices(dataset_i)

x_i = self._subsample_data(dataset[self.data_keys[0]], self._sample_indices)
x_i = self._subsample_data(dataset_i[self.data_keys[0]], self._sample_indices)

ntks = jmap(lambda x_i: self._compute_ntk(params, x_i), x_i)

Expand Down
6 changes: 3 additions & 3 deletions znnl/analysis/jax_ntk_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray:
ntk = self._check_shape(ntk)
return ntk

def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:
def compute_ntk(self, params: dict, dataset_i: dict) -> List[np.ndarray]:
"""
Compute the Neural Tangent Kernel (NTK) for the neural network.
Expand All @@ -286,7 +286,7 @@ def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:
----------
params : dict
The parameters of the neural network.
dataset : dict
dataset_i : dict
The input dataset for the NTK computation.
Returns
Expand All @@ -298,7 +298,7 @@ def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:
"""

# Reduce the dataset to the selected class labels
dataset_reduced = self._reduce_data_to_labels(dataset)
dataset_reduced = self._reduce_data_to_labels(dataset_i)

# Compute the NTK for the reduced dataset
ntk = self._compute_ntk(params, dataset_reduced[self.data_keys[0]])
Expand Down

0 comments on commit c10fcdb

Please sign in to comment.