d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

JAX/Flax Implementation #1825

Closed sazio closed 3 years ago

sazio commented 3 years ago

Have you already considered adding a JAX-based implementation (possibly using Flax for the neural networks) as an alternative to MXNet, TensorFlow, and PyTorch?

astonzhang commented 3 years ago

Yes. It's being planned, but it may take some time.

dbalabka commented 3 years ago

@astonzhang will you accept PRs with JAX implementations? 😎 I'm planning to read the book and try to reimplement the code examples using JAX.

astonzhang commented 3 years ago

@dbalabka Absolutely! We have a jax branch: https://github.com/d2l-ai/d2l-en/tree/jax. Since we are currently refactoring the code for the entire book, it may save you some effort to adapt it to JAX after the refactoring is done :)

dbalabka commented 3 years ago

@astonzhang could you please clarify where I can track the refactoring progress (e.g., a branch or a ticket)?

astonzhang commented 3 years ago

@dbalabka You can watch this repo to stay updated with our release notes.

dbalabka commented 3 years ago

@astonzhang is it possible to bring the JAX branch up to date with the latest changes from master? I can help with that via a PR if needed.

astonzhang commented 2 years ago


Thanks! The refactoring is still ongoing (e.g., see the breaking changes in the pending PRs), and we'll update the JAX branch once it's in better shape.

NightMachinery commented 2 years ago

I suggest going with either pure JAX or Haiku; Flax has too much magic and is not as good a pedagogical choice. Haiku follows Python’s mantra that “explicit is better.”
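For comparison, here is a minimal pure-JAX sketch of what "explicit" means in practice: the parameters of a single dense layer live in a plain dictionary that is passed into every function, and gradients come from `jax.grad`. The helper names (`init_linear`, `linear`, `sgd_step`), the learning rate, and the toy data are hypothetical and chosen only for illustration; this is not code from the book or its jax branch. It also shows the cost of this style: every layer has to be written by hand.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, chosen for this toy example


def init_linear(key, in_dim, out_dim):
    """Create a dense layer's parameters as a plain dict (a JAX pytree)."""
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros(out_dim),
    }


def linear(params, x):
    """Apply the layer; parameters are passed in explicitly, not stored on an object."""
    return x @ params["w"] + params["b"]


def mse_loss(params, x, y):
    """Mean squared error of the linear layer's predictions."""
    return jnp.mean((linear(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    """One step of plain gradient descent on the mean squared error."""
    grads = jax.grad(mse_loss)(params, x, y)  # same dict structure as params
    # tree_map updates every leaf ("w" and "b") of the parameter dict
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# Synthetic toy data from y = 2x + 1 (for illustration only)
x = jax.random.normal(jax.random.PRNGKey(0), (100, 1))
y = 2.0 * x + 1.0

params = init_linear(jax.random.PRNGKey(1), 1, 1)
for _ in range(200):
    params = sgd_step(params, x, y)
```

Nothing is hidden inside a module object here, which is the pedagogical appeal; Haiku and Flax both wrap this same params-as-pytree idea behind their own module APIs.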

dbalabka commented 2 years ago

@NightMachinery, my initial idea was to keep the code as close as possible to the TensorFlow and PyTorch versions. Using pure JAX would mean reimplementing every layer, which would be hard to maintain. It is better to keep it simple: implement all NumPy-level code snippets with JAX, and use Flax, Trax, or Haiku for the neural network snippets. AFAIK, Flax is the most popular of these and, based on GitHub's dependents statistics, the most widely used: https://github.com/google/flax/network/dependents So I would start with Flax and then add implementations for the other libraries listed above.
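To make the proposed split concrete, here is a minimal sketch of how a NumPy-level snippet might be ported to `jax.numpy`. The variable names are illustrative and not taken from the book or its jax branch. The snippet only highlights two differences such a port would have to handle: JAX arrays are immutable, and random numbers need an explicit PRNG key. The neural-network snippets, which the comment proposes writing with Flax, are not covered here.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy version of a typical array-manipulation snippet
x_np = np.arange(12).reshape(3, 4)
x_np[0, 0] = 100  # in-place update
row_sums_np = x_np.sum(axis=1)

# JAX version: jax.numpy mirrors most of the NumPy API...
x_jax = jnp.arange(12).reshape(3, 4)
# ...but JAX arrays are immutable, so an "update" returns a new array
x_jax_updated = x_jax.at[0, 0].set(100)
row_sums_jax = x_jax_updated.sum(axis=1)

# Random numbers use an explicit PRNG key instead of NumPy's global state
key = jax.random.PRNGKey(0)
noise = jax.random.normal(key, (3, 4))
```

Most such snippets would port by swapping `np` for `jnp`; the `.at[...].set(...)` update and the explicit `key` argument are where the JAX version has to diverge from the NumPy original.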