NicolasZucchet / minimal-LRU

Unofficial implementation of the Linear Recurrent Unit (LRU; Orvieto et al., 2023)
MIT License

Linear Vanilla RNNs Code #1

Closed (yuqinzhou9 closed 1 year ago)

yuqinzhou9 commented 1 year ago

Thank you for the insightful repo. Are there any plans to release the code for the experiments on linear vanilla RNNs? I understand the paper mentions that "All other settings in the recurrent block are kept the same as in the Vanilla RNN module of Haiku". Still, it would be great if you could share your implementation details.

NicolasZucchet commented 1 year ago

I don't plan to implement Linear RNNs in this repo in the near future.

However, here are some pointers on how to implement them within this repo:

- Create an `RNN` class in `model.py`, similar to the LRU one. The main differences are that $\lambda$ becomes a fully connected real matrix $A$ and that $B$ and $C$ are real matrices, so the recurrence becomes $x_k = A x_{k-1} + B u_k$, $y_k = C x_k$.
- Use a sequential scan (`jax.lax.scan`) instead of a parallel one (`jax.lax.associative_scan`): with a dense $A$, composing recurrence steps requires matrix-matrix products, so the parallel scan loses its advantage.
- To match the Haiku weight initialization, one probably has to switch the weight initialization from a normal to a truncated normal distribution (see https://github.com/deepmind/dm-haiku/blob/2e941023d70db39f59eaaf67c7b21d5c836cc290/haiku/_src/basic.py#L174).

Apart from that, the rest of the code should stay the same. A sketch along these lines is given below.
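For concreteness, here is a minimal sketch of such a class, assuming the repo's Flax conventions. The class name `LinearRNN`, the field names `d_hidden`/`d_model`, and the elementwise skip parameter `D` (mirroring the LRU class) are illustrative choices, not taken from the repo:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class LinearRNN(nn.Module):
    """Linear vanilla RNN: x_k = A x_{k-1} + B u_k, y_k = C x_k + D * u_k."""

    d_hidden: int  # state dimension (illustrative name)
    d_model: int   # input/output dimension (illustrative name)

    def setup(self):
        # Truncated normal with stddev 1/sqrt(fan_in), matching Haiku's
        # default hk.Linear initialization linked above.
        def trunc_init(fan_in):
            return nn.initializers.truncated_normal(stddev=fan_in ** -0.5)

        self.A = self.param("A", trunc_init(self.d_hidden), (self.d_hidden, self.d_hidden))
        self.B = self.param("B", trunc_init(self.d_model), (self.d_hidden, self.d_model))
        self.C = self.param("C", trunc_init(self.d_hidden), (self.d_model, self.d_hidden))
        self.D = self.param("D", nn.initializers.normal(stddev=1.0), (self.d_model,))

    def __call__(self, inputs):
        """inputs: (L, d_model) -> outputs: (L, d_model)."""

        def step(x, u):
            x = self.A @ x + self.B @ u
            return x, x  # carry the state and also emit it

        x0 = jnp.zeros(self.d_hidden)
        # Sequential scan: with a dense A, an associative scan would need a
        # matrix-matrix product per step, so lax.scan is used instead.
        _, states = jax.lax.scan(step, x0, inputs)
        return jax.vmap(lambda x, u: self.C @ x + self.D * u)(states, inputs)
```

As with the LRU module, this processes a single sequence of shape `(L, d_model)`; batching would go through `jax.vmap`.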

yuqinzhou9 commented 1 year ago

Thank you for your kind reply!