google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

How to compute the empirical after kernel? #189

Open VMS-6511 opened 1 year ago

VMS-6511 commented 1 year ago

I'm looking to use the library to compute the after kernel (the empirical NTK evaluated at the trained parameters) for a model trained with the Flax library. I followed this Colab: https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb.

Instead of taking params from model.init in these lines:

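  # ... (inside get_ntk_fns(O) in the Colab; params are taken at initialization)
  params = model.init(random.PRNGKey(0), x1)
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)
# Empirical NTK between x1 and x2, evaluated at `params`
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)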

I used the params from the following TrainState of the Flax model:

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=tx)

Is this the correct way to do it? Thanks!

romanngg commented 1 year ago

Hi Vinith, yes, I think this is correct. If something isn't working as expected, let me know!
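Why passing the TrainState params works: the kernel function used in the Colab is called as `ntk_fn(x1, x2, params)`, so it evaluates the empirical NTK at whatever `params` it receives, and trained parameters give the after kernel. Below is a minimal pure-JAX sketch of the Jacobian-contraction computation for intuition only, not the library's implementation; it assumes a hypothetical toy MLP (`init_params`, `f`) and uses a few plain gradient-descent steps as a stand-in for training.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=5, width=16, d_out=3):
    # Hypothetical toy 2-layer MLP; params are a dict pytree.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def f(params, x):
    # x: (N, d_in) -> outputs: (N, d_out)
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def empirical_ntk(f, params, x1, x2):
    """Empirical NTK via Jacobian contraction, no trace over outputs.

    Returns K of shape (N1, N2, O, O) with
    K[i, j, a, b] = sum_p d f(x1)[i, a] / dp * d f(x2)[j, b] / dp.
    """
    j1 = jax.jacobian(f)(params, x1)  # leaves: (N1, O, *leaf.shape)
    j2 = jax.jacobian(f)(params, x2)  # leaves: (N2, O, *leaf.shape)

    def contract(a, b):
        a = a.reshape(a.shape[0], a.shape[1], -1)  # (N1, O, P_leaf)
        b = b.reshape(b.shape[0], b.shape[1], -1)  # (N2, O, P_leaf)
        return jnp.einsum("iap,jbp->ijab", a, b)

    # Sum the per-leaf contractions over all parameter arrays.
    return sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(contract, j1, j2)))


key = jax.random.PRNGKey(0)
kp, kx, ky = jax.random.split(key, 3)
params0 = init_params(kp)
x = jax.random.normal(kx, (8, 5))
y = jax.random.normal(ky, (8, 3))


def loss(params):
    return jnp.mean((f(params, x) - y) ** 2)


# Plain gradient descent stands in for "training".
params_T = params0
for _ in range(100):
    grads = jax.grad(loss)(params_T)
    params_T = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params_T, grads)

k_init = empirical_ntk(f, params0, x, x)    # kernel at initialization
k_after = empirical_ntk(f, params_T, x, x)  # after kernel at trained params
print(k_after.shape)  # (8, 8, 3, 3)
```

The two kernels differ because the Jacobian depends on the parameters at which it is taken; that is exactly the difference between the kernel at initialization and the after kernel.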