VMS-6511 opened this issue 1 year ago.
I'd like to use the library to compute the after kernel (the empirical NTK evaluated at the trained parameters) of a model trained with Flax. I followed this Colab: https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb.
Instead of these lines:
```python
def get_ntk_fns(O):
  ...
  params = model.init(random.PRNGKey(0), x1)
  ...
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
```
I used the params from the following `TrainState` of the trained Flax model instead:
```python
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=tx)
```
Is this the correct way to do it? Thanks!
Hi Vinith, yes, I think this is correct. If something isn't working as expected, let me know!