-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Konsti papyrus recording #121
Conversation
Adapt all tests
- Computing CVs Example - Contrastive Loss Example - Using data recorder Example
Adapt examples: - Using Training Strategies - ResNet Example
This allows for subsampling the ntk.
Add option to set the data set keys in the init of the ntk computation
…ation of a system.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The papyrus addition looks good. The additional recorders could have been put into a separate PR, but I read over them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be better to split the PRs but it looks good.
In this PR the JaxRecorder is replaced by a recorder implementation in the package papyrus.
For that the following changes will be applied:
More information
JAXNTKComputation
Re-implementation of the previous NTK computation located in each model.
The now has to set an apply function to construct the class. This allows for setting any function on which the NTK should be recorded.
JAXNTKSubsampling
An NTK computation that approximates the full matrix by computing diagonal blocks of the NTK using user-defined block sizes. The block size can be set via the arg
ntk_size
when constructing the method. The data of each block matrix is assigned randomly.NTKClassWise
An NTK computation that computes kernel for all samples of the same class. Given a data set of 10 classes, one obtains 10 NTKs.
NTKCombinations
An NTK computation that evaluates returns the Kernel for all possible class combinations. Given 2 classes, one obtains the NTK for the samples of class (0), (1) and (0+1). Which classes to be selected can be controlled at construction by setting the arg
class_labels
.