vimalthilak opened 3 years ago
Hi Vimal,
1) Yes, I'm afraid your understanding is correct: currently NT only computes exact infinite-width kernels of networks defined in `nt.stax`. As mentioned in the thread you linked, @sschoenholz is looking into relaxing this constraint, but this is a challenging task and we don't have a precise timeline. So in the meantime you may need to either write a converter between `nt.stax` and FLAX models, or adapt the FLAX utilities to `nt.stax` (either can be quite laborious...).
2) Hard to say: I'm not familiar with it, and it seems they don't have documentation yet. I imagine some parts that don't depend heavily on flax might work (maybe https://github.com/google/CommonLoopUtils/blob/master/clu/metrics.py?), but others that are more specific to flax may not (perhaps https://github.com/google/CommonLoopUtils/blob/master/clu/checkpoint.py?). So you may need to try it out in practice or read through their code to find out. Sorry I can't be more helpful here! (If someone is more familiar with flax/CLU, please let us know if you have any ideas. One other place to ask could be https://github.com/google/jax, regarding whether CLU would work with `jax.experimental.stax`, which is very similar to `nt.stax`.)
Hello neural-tangents (NT) authors,
Thanks for creating and maintaining such a great product.
I have a few questions related to usage:
Any help with the above is very much appreciated!