cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License
215 stars 17 forks

WIP: Implementation of Recurrent layers #45

Closed ptigwe closed 2 years ago

ptigwe commented 2 years ago

Initial implementation of recurrent layers, starting with GRUCell, which is a port of the corresponding implementation in flax.nn.recurrent.GRUCell.

The API is still a WIP and not fixed yet, and it is open to discussion as development goes on.

At the moment, the GRUCell allows for initialization of the starting hidden state using either:

cgarciae commented 2 years ago

Hey @ptigwe! Thanks for taking the initiative.

The first general comment I have is that this reimplements the GRU core logic from scratch, which is something we have avoided so far for the following reason: Flax is well supported, and by delegating to it Treex gets Flax improvements for free as they come, which reduces maintenance cost. If you want some inspiration, take a look at BatchNorm.

ptigwe commented 2 years ago

I was in a bit of a rush and seem to have messed up my git history. I will redo this in a bit.