cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License
215 stars, 17 forks

Initial implementation of GRU layers #48

Closed ptigwe closed 2 years ago

ptigwe commented 2 years ago

This currently shows a working implementation of a GRU layer that closely follows the Keras API, using flax.linen.GRUCell as its backbone.

Tackling the implementation of GRU (#40)

cgarciae commented 2 years ago

Thanks @ptigwe for this! Sorry it took so long to respond, I'd seen it but hadn't had the time to review it. Overall it looks very good, I'll just leave a couple of comments.

cgarciae commented 2 years ago

As discussed offline, we feel we can lift the Keras restrictions on the input shape here by replacing the time_major argument with something like a time_axis argument, which lets you select the dimension you are going to scan over. This change requires a couple of changes to the logic, including figuring out the proper batch_dims for initialize_carry.
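A minimal sketch of the time_axis idea, under stated assumptions: the GRU step here is hand-written in plain JAX (it is not flax.linen.GRUCell and not the code from this PR), and the names init_params, gru_step, and gru_scan are hypothetical. It shows one way to derive batch_dims from the input shape and to scan over an arbitrary axis by moving it to the front for jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_features, hidden):
    # Hypothetical parameter layout: one input and one hidden matrix per gate.
    k1, k2 = jax.random.split(key)
    return {
        "W": 0.1 * jax.random.normal(k1, (3, in_features, hidden)),
        "U": 0.1 * jax.random.normal(k2, (3, hidden, hidden)),
        "b": jnp.zeros((3, hidden)),
    }


def gru_step(params, h, x):
    # A common GRU formulation: reset gate r, update gate z, candidate n.
    W, U, b = params["W"], params["U"], params["b"]
    r = jax.nn.sigmoid(x @ W[0] + h @ U[0] + b[0])
    z = jax.nn.sigmoid(x @ W[1] + h @ U[1] + b[1])
    n = jnp.tanh(x @ W[2] + r * (h @ U[2]) + b[2])
    new_h = (1.0 - z) * n + z * h
    return new_h, new_h


def gru_scan(params, xs, time_axis=-2):
    # Assumes the last axis holds features and time_axis is not the last axis.
    hidden = params["U"].shape[-1]
    axis = time_axis % xs.ndim  # normalize negative axes
    # batch_dims: every dimension except the time axis and the feature axis.
    batch_dims = tuple(d for i, d in enumerate(xs.shape[:-1]) if i != axis)
    h0 = jnp.zeros(batch_dims + (hidden,))
    # jax.lax.scan iterates over the leading axis, so move time to the front.
    xs_t = jnp.moveaxis(xs, axis, 0)
    h_last, hs = jax.lax.scan(lambda h, x: gru_step(params, h, x), h0, xs_t)
    # Put the time axis back where the caller had it.
    return h_last, jnp.moveaxis(hs, 0, axis)
```

With an input of shape [2, 4, 7, 3] and time_axis=-2, the carry has shape [2, 4, hidden] and the stacked outputs have shape [2, 4, 7, hidden]; the same function handles a time-major input with time_axis=0.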

cgarciae commented 2 years ago

Hey @ptigwe! There are a few minor comments / changes left if you'd like to finish them; otherwise I'd be very happy to continue with the PR.

ptigwe commented 2 years ago

@cgarciae, I believe I have addressed most of the comments you mentioned before, including the time_axis change as discussed. The only thing I didn't update was the switch from jax.lax.scan to jax_utils.scan_in_dim, as I was waiting for the change in flax to get merged and the upstream version updated.

Please let me know if there is anything else I missed and I will make the changes ASAP.

cgarciae commented 2 years ago

Thanks @ptigwe! There is a small comment about changing the default time_axis from 0 to -2.

ptigwe commented 2 years ago

@cgarciae, I guess I must have missed that one. I've added it to the PR now. Quick clarification while I'm fixing the errors: with time_axis = -2, the default expected input shape would be [..., T, C], as opposed to the current default of 0 / -3 with [..., T, B, C].

cgarciae commented 2 years ago

Keras uses [B, T, C]; we are going for [..., B, T, C], which is even better. I think this is nicer than [..., T, B, C] because 1D convolutions and transformers also use [..., B, T, C], so the ops would chain better.
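The chaining argument can be sketched as follows. This is an illustration under assumptions, not code from the PR: conv1d and ema_over_time are hypothetical helpers, and the exponential moving average is only a stand-in for a recurrent layer. With a channels-last [B, T, C] layout, the output of a 1D convolution can feed a scan over time_axis=-2 without any transposes in between.

```python
import jax
import jax.numpy as jnp


def conv1d(x, kernel):
    # x: [B, T, C_in], kernel: [width, C_in, C_out]; channels-last layout.
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,),
        padding="SAME",
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def ema_over_time(x, decay=0.9, time_axis=-2):
    # Stand-in for a recurrent layer: exponential moving average over time.
    xs = jnp.moveaxis(x, time_axis, 0)  # scan runs over the leading axis

    def step(carry, x_t):
        carry = decay * carry + (1.0 - decay) * x_t
        return carry, carry

    _, ys = jax.lax.scan(step, jnp.zeros_like(xs[0]), xs)
    return jnp.moveaxis(ys, 0, time_axis)  # restore the original layout


x = jnp.ones((2, 6, 3))                       # [B, T, C]
kernel = jnp.ones((3, 3, 4))                  # [width, C_in, C_out]
features = conv1d(x, kernel)                  # shape (2, 6, 4)
smoothed = ema_over_time(features)            # shape (2, 6, 4), no transpose needed
```

A time-major input would give the same values with time_axis=0, but the convolution output would first have to be transposed to [T, B, C].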

ptigwe commented 2 years ago

OK, cool. It was already set in the previous commit. I also decided not to include the change to jax_utils.scan_in_dim, as the version currently being pointed to, 0.3.6, does not have the unrolled added to it yet.
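For context, jax.lax.scan itself accepts an unroll argument. Whether that corresponds exactly to the unrolling support awaited in scan_in_dim is an assumption here. A small sketch showing that unrolling changes how the loop is compiled but not the result (cumulative_sum_step is a hypothetical example function):

```python
import jax
import jax.numpy as jnp


def cumulative_sum_step(carry, x):
    # Scan body: returns the new carry and the per-step output.
    carry = carry + x
    return carry, carry


xs = jnp.arange(8.0)
init = jnp.zeros(())

# Default: one loop iteration per scan step.
_, rolled = jax.lax.scan(cumulative_sum_step, init, xs)

# unroll=4: four scan steps are unrolled into each loop iteration.
_, unrolled = jax.lax.scan(cumulative_sum_step, init, xs, unroll=4)
```

Both calls compute the same running sum; unrolling only trades compile time and code size for potentially fewer loop iterations at run time.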

cgarciae commented 2 years ago

If you want to update flax to the latest version to add this feature now you can run:

poetry add flax@latest

ptigwe commented 2 years ago

It seems the latest is indeed 0.3.6, which does not have the updated scan_in_dim. Whenever it gets updated, we can update the code to make use of it. I've also included some comments on things that might need changing in that case.

cgarciae commented 2 years ago

@ptigwe Sounds good. I'll merge this for now; we can create a new PR with scan_in_dim later.

Thanks a lot for pushing this through!