Yes. It's under planning but may take some time.
@astonzhang will you accept PRs with JAX implementations? I'm planning to read the book and try to reimplement code examples using JAX
@dbalabka Absolutely! We have a jax branch https://github.com/d2l-ai/d2l-en/tree/jax. As we are currently refactoring the code for the entire book, it may save you effort to adapt to JAX after the code refactoring :)
@astonzhang could you please clarify where I can track the refactoring progress (e.g., a branch or a ticket)?
@dbalabka You may watch this repo to stay updated with our release notes
@astonzhang is it possible to update the jax branch to make it up-to-date with the latest changes from master? I can help with that via a PR if needed.
Thanks! The refactoring is still ongoing (e.g., see the breaking changes in the pending PRs), and we'll update the jax branch once it's in better shape.
I suggest going with either pure JAX or Haiku; Flax has too much magic and is not as good a pedagogical choice. Haiku abides by Python's mantra that "explicit is better."
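As a rough sketch of that explicitness, here is a minimal Haiku example (the two-layer MLP is a hypothetical illustration, not code from the book, assuming `dm-haiku` and `jax` are installed). Parameters and RNG state are plain values passed around explicitly:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Layers are declared inline; hk.transform turns this into a pure function.
    mlp = hk.Sequential([hk.Linear(16), jax.nn.relu, hk.Linear(1)])
    return mlp(x)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
params = model.init(rng, x)      # parameters are an explicit pytree, not hidden state
y = model.apply(params, rng, x)  # params and rng are passed explicitly on every call
```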
@NightMachinary, my initial idea was to keep the code much closer to the TensorFlow and PyTorch versions. Using pure JAX would mean reimplementing every layer, which would be problematic to support. It is better to keep it simple: implement all NumPy-level code snippets with JAX, and use Flax, Trax, or Haiku for the neural-network snippets. AFAIK Flax is the most popular and, judging by GitHub dependents statistics, the most widely used: https://github.com/google/flax/network/dependents So I would start with Flax and then possibly add any of the other libraries listed above.
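For illustration, a minimal sketch of that split, assuming `jax` and `flax` are installed (the `MLP` module below is hypothetical, not code from the book):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

# NumPy-level snippet: plain jax.numpy, mirroring the book's NumPy-style code.
x = jnp.arange(12.0).reshape(3, 4)
print(x.sum(axis=0))

# Neural-network snippet: a Flax module, close in spirit to the PyTorch/Keras versions.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 4)))
y = model.apply(params, jnp.ones((4, 4)))
```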
Have you already considered adding a JAX-based implementation (e.g., Flax for NNs) as an alternative to the MXNet, TensorFlow, and PyTorch ones?