In this PR the JaxRecorder is replaced by a recorder implementation from the papyrus package. To that end, the following changes are applied:
More information
JAXNTKComputation
Re-implementation of the previous NTK computation located in each model.
The user now has to provide an apply function to construct the class. This makes it possible to record the NTK of any function.
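To illustrate why passing an apply function is useful, here is a minimal sketch of an empirical NTK computed for an arbitrary function of the form apply_fn(params, x). The names apply_fn, flat_jacobian and empirical_ntk, as well as the tiny linear model, are assumptions made for this example and are not the API of JAXNTKComputation.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Hypothetical model for illustration: a linear map with one output.
    return x @ params["w"] + params["b"]


def flat_jacobian(apply_fn, params, x):
    # Jacobian of all outputs w.r.t. every parameter leaf,
    # flattened to shape (n_outputs, n_parameters).
    jac = jax.jacobian(apply_fn)(params, x)
    n_out = apply_fn(params, x).size
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate([leaf.reshape(n_out, -1) for leaf in leaves], axis=1)


def empirical_ntk(apply_fn, params, x1, x2):
    # NTK(x1, x2) = J(x1) J(x2)^T, contracted over all parameters.
    return flat_jacobian(apply_fn, params, x1) @ flat_jacobian(apply_fn, params, x2).T


params = {"w": jnp.array([[1.0], [2.0], [3.0]]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]])
ntk = empirical_ntk(apply_fn, params, x, x)
```

Because the kernel only depends on apply_fn and its parameters, any callable with this signature could be swapped in. For the linear model above, the result equals x @ x.T + 1.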
JAXNTKSubsampling
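As a rough illustration of the block-diagonal idea behind this method, the sketch below splits sample indices into random, non-overlapping blocks of ntk_size samples; an NTK would then be computed for each block separately. The helper random_blocks, the seed argument, and the choice to drop leftover samples are assumptions for this example, not the behavior of the actual class.

```python
import random


def random_blocks(n_samples, ntk_size, seed=0):
    # Hypothetical helper: shuffle the sample indices, then cut them
    # into consecutive blocks of ntk_size indices each.
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_blocks = n_samples // ntk_size  # assumption: leftover samples are dropped
    return [indices[i * ntk_size:(i + 1) * ntk_size] for i in range(n_blocks)]


blocks = random_blocks(n_samples=10, ntk_size=4)
```

With 10 samples and a block size of 4, this yields two blocks of four samples each, so two 4x4 kernels would be computed instead of one 10x10 kernel.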
An NTK computation that approximates the full matrix by computing diagonal blocks of the NTK using user-defined block sizes. The block size can be set via the arg ntk_size when constructing the method. The data of each block matrix is assigned randomly.
NTKClassWise
An NTK computation that computes a separate kernel for each class, using all samples of that class. Given a data set with 10 classes, one obtains 10 NTKs.
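The class-wise grouping can be sketched as follows: sample indices are collected per label, and one NTK would then be computed on each group. The helper indices_per_class and the integer labels are assumptions for illustration only.

```python
from collections import defaultdict


def indices_per_class(labels):
    # Hypothetical helper: map each class label to the indices of its samples.
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    return dict(sorted(groups.items()))


labels = [0, 2, 1, 0, 2, 2]
groups = indices_per_class(labels)
```

Here three classes are present, so three groups are produced and three NTKs would be computed.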
NTKCombinations
An NTK computation that returns the kernel for all possible class combinations. Given 2 classes, one obtains the NTK for the samples of class (0), (1) and (0+1). Which classes are selected can be controlled at construction by setting the arg class_labels.
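To make the enumeration concrete, the sketch below lists every non-empty combination of the selected classes and picks the sample indices belonging to one of them. The helpers class_combinations and indices_for_combination are hypothetical names used only for this example; the argument mirrors class_labels from the description above.

```python
from itertools import combinations


def class_combinations(class_labels):
    # All non-empty subsets of the selected classes, smallest first.
    combos = []
    for r in range(1, len(class_labels) + 1):
        combos.extend(combinations(class_labels, r))
    return combos


def indices_for_combination(labels, combo):
    # Indices of all samples whose label belongs to the combination.
    return [i for i, label in enumerate(labels) if label in combo]


combos = class_combinations([0, 1])
labels = [0, 1, 1, 0, 2]
selected = indices_for_combination(labels, (0, 1))
```

For two classes this gives the three combinations (0), (1) and (0+1). In general, n selected classes lead to 2^n - 1 combinations, so the number of kernels grows quickly with the number of classes.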