Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Konsti papyrus recording #121

Merged
merged 40 commits into from
Jun 10, 2024
Merged

Konsti papyrus recording #121

merged 40 commits into from
Jun 10, 2024

Conversation

KonstiNik
Copy link
Member

@KonstiNik KonstiNik commented Jun 8, 2024

In this PR the JaxRecorder is replaced by a recorder implementation in the package papyrus.
For that the following changes will be applied:

  • Create a class for Recording with the papyrus package, see JaxRecorder in znnl.training_recording.papyrus_jax_recording.py.
  • Move the NTK computation from a model attribute to an independent class in the analysis module: JAXNTKComputation
  • Create NTK sampling strategies: JAXNTKSubsampling, JAXNTKClassWise, JAXNTKCombinations

More information

JAXNTKComputation

Re-implementation of the previous NTK computation located in each model.
The now has to set an apply function to construct the class. This allows for setting any function on which the NTK should be recorded.

JAXNTKSubsampling

An NTK computation that approximates the full matrix by computing diagonal blocks of the NTK using user-defined block sizes. The block size can be set via the arg ntk_size when constructing the method. The data of each block matrix is assigned randomly.

NTKClassWise

An NTK computation that computes kernel for all samples of the same class. Given a data set of 10 classes, one obtains 10 NTKs.

NTKCombinations

An NTK computation that evaluates returns the Kernel for all possible class combinations. Given 2 classes, one obtains the NTK for the samples of class (0), (1) and (0+1). Which classes to be selected can be controlled at construction by setting the arg class_labels.

knikolaou added 30 commits May 10, 2024 16:48
- Computing CVs Example
- Contrastive Loss Example
- Using data recorder Example
Adapt examples:
- Using Training Strategies
- ResNet Example
This allows for subsampling the ntk.
Add option to set the data set keys in the init of the ntk computation
@KonstiNik KonstiNik requested a review from SamTov June 8, 2024 20:13
Copy link
Member

@SamTov SamTov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The papyrus addition looks good. The additional recorders could have been put into a separate PR, but I read over them.

.github/workflows/nbtest.yml Outdated Show resolved Hide resolved
Copy link
Member

@SamTov SamTov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be better to split the PRs but it looks good.

@KonstiNik KonstiNik merged commit 30c1e70 into main Jun 10, 2024
6 checks passed
@KonstiNik KonstiNik deleted the Konsti_papyrus_recording branch December 13, 2024 09:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants