google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

neural tangents vs flax #162

Closed · po-oya closed this issue 2 years ago

po-oya commented 2 years ago

Hi

I am new to JAX, and my question is: can we use this package to train regular neural networks, without the NTK machinery? If so, are there any examples or documentation to start with?

Thanks!

romanngg commented 2 years ago

Yes, check out this description https://github.com/google/neural-tangents#5-minute-intro and an example in https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/weight_space_linearization.ipynb

In short, in neural tangents you always define both finite and infinite networks simultaneously, e.g. in https://github.com/google/neural-tangents#infinitely-wideresnet

```python
init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
```

init_fn and apply_fn define your finite-width network, and you can use them however you want without caring about the NTK. They are counterparts to FLAX's model.init and model.apply.
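
To make that concrete, here is a minimal sketch of training such a finite-width network with plain JAX, assuming a toy regression setup (the `stax` layers, data shapes, learning rate, and squared loss below are illustrative, not from this thread):

```python
import jax
import jax.numpy as jnp
from jax import random
from neural_tangents import stax

# init_fn / apply_fn / kernel_fn always come back together; only the
# first two are needed for ordinary finite-width training.
init_fn, apply_fn, _ = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(1),
)

key, x_key, y_key = random.split(random.PRNGKey(0), 3)
x = random.normal(x_key, (8, 32))  # toy batch: 8 examples, 32 features
y = random.normal(y_key, (8, 1))   # toy regression targets

# Counterpart of Flax's model.init: random parameters for a given input shape.
_, params = init_fn(key, x.shape)

# Counterpart of Flax's model.apply: parameters and inputs in, outputs out.
def loss(params, x, y):
    return jnp.mean((apply_fn(params, x) - y) ** 2)

# One plain SGD step via jax.grad; no NTK machinery involved.
grads = jax.grad(loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```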

Sohl-Dickstein commented 2 years ago

Note that if all you want to do is train finite-width neural networks, and you are not interested in computing the network's NTK or NNGP kernel or a linearization of the network, then FLAX would be a better choice.
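
For contrast, here is a minimal sketch of the parts neural tangents adds on top of finite-width training, i.e. what Flax does not cover: exact infinite-width NNGP/NTK kernels via kernel_fn, closed-form inference via nt.predict, and a linearization of the finite network via nt.linearize (the network and data shapes are again illustrative):

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(1),
)

key, train_key, test_key = random.split(random.PRNGKey(0), 3)
x_train = random.normal(train_key, (20, 32))
y_train = random.normal(train_key, (20, 1))
x_test = random.normal(test_key, (5, 32))

# Exact NTK of the infinitely wide network between two datasets.
k_test_train = kernel_fn(x_test, x_train, 'ntk')

# Closed-form predictions of an ensemble of infinitely wide networks
# trained to convergence by gradient descent on MSE.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_ntk = predict_fn(x_test=x_test, get='ntk')

# Linearization of the *finite-width* network around its initialization.
_, params = init_fn(key, x_train.shape)
f_lin = nt.linearize(apply_fn, params)
```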

po-oya commented 2 years ago

Thanks for your help!