zincware / ZnNL

A Python package for studying neural learning
Eclipse Public License 2.0
6 stars, 1 fork

Konsti papyrus recording #121

Closed: KonstiNik closed this pull request 5 months ago

KonstiNik commented 5 months ago

In this PR, the JaxRecorder is replaced by the recorder implementation from the package papyrus. To that end, the following changes are applied:

More information

JAXNTKComputation

A re-implementation of the NTK computation that was previously located in each model.
The user now has to pass an apply function when constructing the class. This makes it possible to record the NTK of any function.
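The idea of building the NTK computation around a user-supplied apply function can be illustrated with a minimal NumPy sketch. This is not the ZnNL/JAX implementation. The class name `SimpleNTKComputation`, its `compute_ntk` method, and the toy `linear_apply` model are hypothetical. The sketch also replaces JAX autodiff with central finite differences and traces the kernel over model outputs, both of which are simplifying assumptions.

```python
import numpy as np


class SimpleNTKComputation:
    """Hypothetical sketch of an NTK computation built around an apply function."""

    def __init__(self, apply_fn, eps=1e-6):
        # apply_fn(params, x) must return outputs of shape (n_samples, n_outputs);
        # params is assumed to be a flat float array.
        self.apply_fn = apply_fn
        self.eps = eps

    def _jacobian(self, params, x):
        # Central finite differences with respect to each parameter
        # (a stand-in for the automatic differentiation JAX would provide).
        out = self.apply_fn(params, x)
        jac = np.zeros(out.shape + (params.size,))
        for i in range(params.size):
            step = np.zeros_like(params)
            step[i] = self.eps
            diff = self.apply_fn(params + step, x) - self.apply_fn(params - step, x)
            jac[..., i] = diff / (2 * self.eps)
        return jac  # shape (n_samples, n_outputs, n_params)

    def compute_ntk(self, params, x):
        jac = self._jacobian(params, x)
        # K[i, j] = sum over outputs o and parameters p of J[i, o, p] * J[j, o, p]
        return np.einsum("iop,jop->ij", jac, jac)


def linear_apply(params, x):
    # Toy model f(x) = x @ w with a single output; its NTK is exactly x @ x.T.
    return x @ params.reshape(-1, 1)


x = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
ntk = SimpleNTKComputation(linear_apply).compute_ntk(np.array([0.3, -0.7]), x)
```

Because the computation only depends on `apply_fn`, any function of parameters and inputs (a full network, a single layer, or a toy model as above) can be plugged in without changing the NTK code.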

JAXNTKSubsampling

An NTK computation that approximates the full matrix by computing only diagonal blocks of the NTK with a user-defined block size. The block size is set via the argument ntk_size when constructing the class. The samples assigned to each block are chosen randomly.
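The block-diagonal approximation can be sketched as follows, assuming a NumPy setting. The function `subsampled_ntk` and the stand-in kernel `linear_ntk` (the exact NTK of a linear model) are hypothetical names, not the ZnNL API. How ZnNL treats samples left over when the data set size is not a multiple of ntk_size is not stated in the PR; this sketch simply drops them.

```python
import numpy as np


def linear_ntk(x):
    # Stand-in kernel: the NTK of a linear model f(x) = x @ w is x @ x.T.
    return x @ x.T


def subsampled_ntk(kernel_fn, x, ntk_size, seed=0):
    """Hypothetical sketch: approximate the full NTK by random diagonal blocks.

    Samples are shuffled, split into blocks of ntk_size, and one kernel is
    computed per block. Leftover samples (len(x) % ntk_size) are dropped here.
    """
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(x))
    n_blocks = len(x) // ntk_size
    blocks = []
    for b in range(n_blocks):
        idx = perm[b * ntk_size:(b + 1) * ntk_size]
        # Each entry pairs the sample indices with their block of the kernel.
        blocks.append((idx, kernel_fn(x[idx])))
    return blocks


x = np.random.default_rng(1).normal(size=(10, 4))
blocks = subsampled_ntk(linear_ntk, x, ntk_size=3)
```

Computing only the diagonal blocks reduces the number of kernel entries from n² to roughly n · ntk_size, at the cost of ignoring correlations between samples in different blocks.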

NTKClassWise

An NTK computation that computes a separate kernel for the samples of each class. Given a data set with 10 classes, one obtains 10 NTKs.
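A minimal sketch of the class-wise split, assuming integer class labels (one-hot targets could first be converted with `argmax`). The function `class_wise_ntk` and the stand-in kernel `linear_ntk` are hypothetical and only illustrate the grouping logic, not the ZnNL implementation.

```python
import numpy as np


def linear_ntk(x):
    # Stand-in kernel: the NTK of a linear model f(x) = x @ w is x @ x.T.
    return x @ x.T


def class_wise_ntk(kernel_fn, x, labels):
    """Hypothetical sketch: one kernel per class, built from that class's samples."""
    return {int(c): kernel_fn(x[labels == c]) for c in np.unique(labels)}


x = np.random.default_rng(0).normal(size=(7, 3))
labels = np.array([0, 1, 2, 0, 1, 0, 2])
ntks = class_wise_ntk(linear_ntk, x, labels)
```

Each kernel has shape (n_c, n_c), where n_c is the number of samples in class c, so the three classes here yield kernels of size 3, 2, and 2.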

NTKCombinations

An NTK computation that returns the kernel for all possible combinations of classes. Given 2 classes, one obtains the NTK for the samples of class (0), of class (1), and of classes (0+1). Which classes are included can be controlled at construction by setting the argument class_labels.
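The combination logic can be sketched as below, interpreting "all possible class combinations" as every non-empty subset of the selected class_labels. The function `combination_ntks` and the stand-in kernel `linear_ntk` are hypothetical, and integer labels are assumed; this is an illustration, not the ZnNL implementation.

```python
import itertools

import numpy as np


def linear_ntk(x):
    # Stand-in kernel: the NTK of a linear model f(x) = x @ w is x @ x.T.
    return x @ x.T


def combination_ntks(kernel_fn, x, labels, class_labels):
    """Hypothetical sketch: one kernel per non-empty combination of selected classes."""
    ntks = {}
    for r in range(1, len(class_labels) + 1):
        for combo in itertools.combinations(class_labels, r):
            # Keep only the samples whose label belongs to this combination.
            mask = np.isin(labels, combo)
            ntks[combo] = kernel_fn(x[mask])
    return ntks


x = np.random.default_rng(0).normal(size=(5, 3))
labels = np.array([0, 1, 1, 0, 2])
ntks = combination_ntks(linear_ntk, x, labels, class_labels=[0, 1])
```

Samples of classes not listed in class_labels (class 2 here) are ignored. Since k selected classes give 2^k − 1 combinations, restricting class_labels keeps the number of kernels manageable.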