banerjeeshubhanker opened this issue 2 years ago
Yeah, totally. Please take a look at the attention branch; there is some initial design there: https://github.com/zhangxiangxiao/xjax/blob/attention/xjax/xar.py
Feel free to use the design or ignore it.
Thank you, that's very helpful
@zhangxiangxiao I've been looking at the codebase and am trying to implement an RNN layer to get started. I've come across GRUCell in flax, which could be used as a building block: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.GRUCell.html. Do you think the RNN layer needs to be implemented from scratch, or can we add a wrapper around the existing functionality in flax?
The purpose of XJAX is to avoid the complexity of other frameworks such as FLAX. XJAX's simplicity should extend to the readability of its internal implementations, which expose JAX numerics directly. Therefore we should not wrap FLAX.
That said, the FLAX codebase is a great source of reference for the numerical details on how to implement a module from scratch.
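To make the "from scratch" suggestion concrete, here is a minimal sketch of a GRU written directly against JAX numerics, with no FLAX dependency. It assumes the standard GRU gate equations (the same form described in the FLAX GRUCell docs linked above) and a plain-dict parameter layout. The names `gru_init`, `gru_cell` and `gru_layer` are hypothetical and are not part of XJAX's existing API.

```python
import jax
import jax.numpy as jnp


def gru_init(key, in_dim, hidden_dim):
    """Hypothetical initializer: returns a dict of GRU parameters."""
    k_i, k_h = jax.random.split(key)
    return {
        # Input-to-hidden weights for the r, z, n gates, stacked on the last axis.
        "w_i": jax.random.normal(k_i, (in_dim, 3 * hidden_dim)) * in_dim ** -0.5,
        # Hidden-to-hidden weights for the r, z, n gates.
        "w_h": jax.random.normal(k_h, (hidden_dim, 3 * hidden_dim)) * hidden_dim ** -0.5,
        # Biases on both sides; for r and z they are redundant, for n they are not.
        "b_i": jnp.zeros((3 * hidden_dim,)),
        "b_h": jnp.zeros((3 * hidden_dim,)),
    }


def gru_cell(params, h, x):
    """One GRU step. h: [..., hidden_dim], x: [..., in_dim]. Returns the new h."""
    i_r, i_z, i_n = jnp.split(x @ params["w_i"] + params["b_i"], 3, axis=-1)
    h_r, h_z, h_n = jnp.split(h @ params["w_h"] + params["b_h"], 3, axis=-1)
    r = jax.nn.sigmoid(i_r + h_r)   # reset gate
    z = jax.nn.sigmoid(i_z + h_z)   # update gate
    n = jnp.tanh(i_n + r * h_n)     # candidate state
    return (1.0 - z) * n + z * h    # interpolate between candidate and old state


def gru_layer(params, h0, xs):
    """Run the cell over xs of shape [time, ..., in_dim]; returns all hidden states."""
    def step(h, x):
        h_new = gru_cell(params, h, x)
        return h_new, h_new  # (carry, per-step output)
    _, hs = jax.lax.scan(step, h0, xs)
    return hs


if __name__ == "__main__":
    params = gru_init(jax.random.PRNGKey(0), in_dim=4, hidden_dim=8)
    xs = jnp.ones((5, 4))  # 5 time steps, 4 features each
    hs = gru_layer(params, jnp.zeros((8,)), xs)
    print(hs.shape)  # (5, 8)
```

Using `jax.lax.scan` for the time loop keeps the layer jit-compatible and avoids a Python loop over time steps; how this would map onto XJAX's own module conventions is left open.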
That makes perfect sense! Thank you
Hi, thank you for developing this awesome library. Can I take up the development of RNNs for this library?