google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

better support for RNN layers required #2170

Open sourabh2k15 opened 2 years ago

sourabh2k15 commented 2 years ago

We have a use case where we're implementing the DeepSpeech2 model in Flax. DeepSpeech2 is an older speech recognition model built on RNN-style layers (commonly Bi-LSTMs).

Flax doesn't have Bi-LSTMs, so we hacked together our own version based on the existing RNNCell, but I messed up the handling of padding, which caused a long debugging loop.

Eventually we found a Flax BiLSTM layer that others had implemented. It flips the sequences to run the LSTM in the reverse direction and then flips the output back, which worked for our use case involving padded inputs.
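To make the flip trick concrete, here is a minimal sketch of a padding-aware flip in JAX. The helper name `flip_padded` and the `(batch, time, features)` layout are assumptions for illustration; this is not the code of the layer mentioned above. The idea is to reverse only the valid prefix of each sequence so that padding stays at the end.

```python
import jax.numpy as jnp


def flip_padded(x, lengths):
    """Reverse each sequence along time while keeping padding at the end.

    x:       (batch, time, features) padded inputs (assumed layout)
    lengths: (batch,) number of valid steps in each sequence
    """
    max_len = x.shape[1]
    # Positions max_len-1, ..., 0, shifted per sequence by its length so that
    # the valid prefix is reversed in place and padding remains at the tail.
    idx = jnp.arange(max_len - 1, -1, -1)[None, :]   # (1, time)
    idx = (idx + lengths[:, None]) % max_len         # (batch, time)
    return jnp.take_along_axis(x, idx[:, :, None], axis=1)


# Two sequences of max length 4; the first has only 2 valid steps.
x = jnp.arange(8).reshape(2, 4, 1)
lengths = jnp.array([2, 4])
flipped = flip_padded(x, lengths)
print(flipped[..., 0])
# row 0: [1, 0, 3, 2]  (valid steps 0,1 reversed; padding 2,3 stays at the end)
# row 1: [7, 6, 5, 4]  (full-length sequence reversed entirely)
```

Applying the same flip to the outputs of the reverse-direction LSTM restores the original time order, since flipping twice with the same lengths is the identity.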

Overall, the current RNN layers in Flax feel very bare-bones compared to PyTorch, which does RNNs really well. It would be amazing to have full-fledged Bi-LSTM, GRU, and RNN layers ready to go. Currently, folks even have to write their own wrapper that uses nn.scan around the default Flax cell primitives to get an RNN.

marcvanzee commented 2 years ago

cc @cgarciae, who is working on better RNN support in #2126.

chiamp commented 11 months ago

Hi @sourabh2k15, we have a Bidirectional class. Would that work for you?

import jax
import jax.numpy as jnp
import flax.linen as nn

module = nn.Bidirectional(nn.RNN(nn.GRUCell(5)), nn.RNN(nn.GRUCell(5)))
x = jnp.ones((7, 3))  # unbatched input: (time, features)
v = module.init(jax.random.PRNGKey(0), x)
out = module.apply(v, x)