zhangxiangxiao / xjax

Simple framework for neural networks using Jax
BSD 3-Clause "New" or "Revised" License

Implementation of Recurrent Neural Network #1

Open banerjeeshubhanker opened 2 years ago

banerjeeshubhanker commented 2 years ago

Hi, Thank you for developing this awesome library. Can I take up the development of RNN for this library?

zhangxiangxiao commented 2 years ago

Yeah totally. Please take a look at the attention branch. Some initial design is there https://github.com/zhangxiangxiao/xjax/blob/attention/xjax/xar.py

Feel free to use the design or ignore it.

banerjeeshubhanker commented 2 years ago

Thank you, that's very helpful

banerjeeshubhanker commented 2 years ago

@zhangxiangxiao I've been looking at the codebase and trying to implement an RNN layer to get started. I've come across GRUCell in flax, which could be used as a building block: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.GRUCell.html. Do you think the RNN layer needs to be implemented from scratch, or can we add a wrapper around the existing functionality in flax?

zhangxiangxiao commented 2 years ago

The purpose of XJAX is to avoid the complexity of other frameworks such as FLAX. XJAX's simplicity should extend to the readability of its internal implementations, which expose jax numerics directly. Therefore we should not wrap around FLAX.

That said, the FLAX codebase is a great reference for the numerical details of implementing a module from scratch.
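To illustrate the "expose jax numerics directly" approach being discussed, here is a minimal sketch of a GRU implemented as plain functions over raw jax arrays, with no framework wrapper. The function names (`gru_init`, `gru_cell`, `gru_forward`) and the parameter dictionary layout are hypothetical choices for this example, not part of the xjax API; the gate equations follow the standard GRU formulation that flax's GRUCell also uses.

```python
import jax
import jax.numpy as jnp

def gru_init(rng, input_size, hidden_size):
    """Initialize GRU parameters as a plain dict of arrays.
    Uniform init in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
    k = 1.0 / jnp.sqrt(hidden_size)
    keys = jax.random.split(rng, 6)
    def u(key, shape):
        return jax.random.uniform(key, shape, minval=-k, maxval=k)
    return {
        "w_ir": u(keys[0], (hidden_size, input_size)),   # reset gate, input
        "w_hr": u(keys[1], (hidden_size, hidden_size)),  # reset gate, hidden
        "w_iz": u(keys[2], (hidden_size, input_size)),   # update gate, input
        "w_hz": u(keys[3], (hidden_size, hidden_size)),  # update gate, hidden
        "w_in": u(keys[4], (hidden_size, input_size)),   # candidate, input
        "w_hn": u(keys[5], (hidden_size, hidden_size)),  # candidate, hidden
    }

def gru_cell(params, h, x):
    """One GRU step on a single input vector; returns the new hidden state."""
    r = jax.nn.sigmoid(params["w_ir"] @ x + params["w_hr"] @ h)  # reset gate
    z = jax.nn.sigmoid(params["w_iz"] @ x + params["w_hz"] @ h)  # update gate
    n = jnp.tanh(params["w_in"] @ x + r * (params["w_hn"] @ h))  # candidate
    return (1.0 - z) * n + z * h

def gru_forward(params, h0, xs):
    """Run the cell over a sequence of shape (time, input_size)
    using jax.lax.scan; returns the final state and all states."""
    def step(h, x):
        h_new = gru_cell(params, h, x)
        return h_new, h_new
    return jax.lax.scan(step, h0, xs)
```

Because the parameters are an ordinary pytree and the forward pass is a pure function, the whole thing composes directly with `jax.jit`, `jax.grad`, and `jax.vmap` without any module machinery.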

banerjeeshubhanker commented 2 years ago

That makes perfect sense! Thank you