zincware / ZnNL

A Python package for studying neural learning
Eclipse Public License 2.0

More flexibility in NTK recording #110

Open KonstiNik opened 9 months ago

KonstiNik commented 9 months ago

By default, the current implementation of the JaxRecorder takes the ntk_apply function defined through the model. For some NTK computations (such as the loss NTK or certain Fisher NTK approximations), the NTK apply function does not correspond to the model apply function. It would therefore be reasonable to let the user set this function manually for each recorder, since one might want to record multiple versions of the NTK during a single training run.

The suggestion is to move the ntk_apply function out of the model into a separate class. That class would handle all NTK computation and would be constructed from an apply function (e.g. that of the model). We would need one NTK computation class per model.
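A minimal sketch of what such a class could look like. This is not ZnNL code: the names `NTKComputation`, `apply_fn`, and `compute_ntk` are hypothetical, and the finite-difference Jacobian is just a stand-in for an autodiff backend. The point is only the structure: the class is built from an arbitrary apply function, so a recorder can hold one instance per NTK variant it wants to record.

```python
import numpy as np

class NTKComputation:
    """Hypothetical sketch: empirical NTK detached from the model.

    Constructed from any apply function (params, x) -> outputs, so the
    same machinery can be pointed at the raw model, the model plus
    softmax, or the model plus loss.
    """

    def __init__(self, apply_fn):
        self.apply_fn = apply_fn

    def _jacobian(self, params, x, eps=1e-6):
        # Finite-difference Jacobian of the outputs w.r.t. the flattened
        # parameters (a real implementation would use autodiff, e.g. jax.jacobian).
        base = np.asarray(self.apply_fn(params, x)).ravel()
        flat = np.asarray(params).ravel()
        jac = np.zeros((base.size, flat.size))
        for i in range(flat.size):
            perturbed = flat.copy()
            perturbed[i] += eps
            out = np.asarray(self.apply_fn(perturbed.reshape(np.shape(params)), x)).ravel()
            jac[:, i] = (out - base) / eps
        return jac

    def compute_ntk(self, params, x1, x2):
        # Empirical NTK: inner product of the two Jacobians.
        j1 = self._jacobian(params, x1)
        j2 = self._jacobian(params, x2)
        return j1 @ j2.T
```

Each recorder would then simply receive an `NTKComputation` instance (or its `compute_ntk` callable) instead of reaching into the model for `ntk_apply`.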

SamTov commented 9 months ago

Do you think a single NTK function should handle all of these different versions? Since the loss NTK and the Fisher are different things, it would make more sense to have a dedicated loss-NTK or Fisher calculator somewhere, rather than overloading the single NTK computation with a bunch of additional arguments and options. They could share a backend, though.

KonstiNik commented 9 months ago

I agree. The point I was trying to make is that the NTK computation is currently part of the model, so at the moment we cannot record an NTK that differs from the exact model function. One example: you want to record the NTK of the model with a softmax output, but the loss is softmax-cross-entropy, which already includes the softmax. Another example is recording the NTK of the model plus the loss function. Moving the NTK computation from the model into a separate class does affect the recorders, since the NTK computation would then be passed to them as a callable.
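To make the two examples concrete, here is a hedged sketch of what the different apply callables could look like. None of these names come from ZnNL (`model_apply`, `softmax_apply`, `loss_apply` are all hypothetical); the point is just that each variant is an ordinary function with the same `(params, x)` shape, so any of them could be handed to a recorder as its NTK apply function.

```python
import numpy as np

def model_apply(params, x):
    # Hypothetical raw model: logits of a linear classifier.
    return x @ params

def softmax(z):
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_apply(params, x):
    # Model + softmax output: one NTK a recorder might want.
    return softmax(model_apply(params, x))

def loss_apply(params, x, targets):
    # Model + softmax-cross-entropy loss: the "loss NTK" variant,
    # where the softmax is already folded into the loss.
    probs = softmax(model_apply(params, x))
    return -np.sum(targets * np.log(probs), axis=-1)
```

With the NTK computation factored out of the model, a training run could record the NTK of `model_apply`, `softmax_apply`, and (with targets bound) `loss_apply` side by side, each through its own recorder.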