cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Bumps `flax` to `0.4.0` #60

Closed: ptigwe closed this 2 years ago

ptigwe commented 2 years ago

Updates flax to the most recent version. This breaks the current implementation, specifically the way rng keys for dropout are handled.

I have currently disabled one of the dropout equivalence tests, as I am not sure whether there is a way to directly set the value of `next_key` within a treex module.
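For context on why that test is sensitive: dropout's output is a deterministic function of the PRNG key, so two implementations only agree if they feed exactly the same key into the same sampling call. A minimal sketch in plain JAX, assuming a hypothetical `dropout` helper written in the usual inverted-dropout form (not the actual treex or Flax code):

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate):
    # Hypothetical inverted dropout: keep each element with probability
    # 1 - rate and rescale survivors so the expected value is unchanged.
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((4, 8))
key = jax.random.PRNGKey(0)

# Same key -> identical masks, so two implementations can be compared.
same_a = dropout(x, key, rate=0.5)
same_b = dropout(x, key, rate=0.5)

# A key derived differently (here: one extra split) gives a different mask.
_, other_key = jax.random.split(key)
other = dropout(x, other_key, rate=0.5)
```

If a library upgrade changes how the key reaching the sampling call is derived, the outputs diverge even though the dropout math itself is unchanged.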

cgarciae commented 2 years ago

Hey @ptigwe! Thanks a lot for this :)

> Currently have disabled one of the dropout equivalence tests as I am not fully aware if there is a method of directly affecting the value of `next_key` within a treex module.

Yeah, simulating the rng key splits is tricky (I've fought with it in the past). I managed to figure it out, though, so I pushed the fix and also applied pre-commit, which was failing.
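One general way to "simulate the rng key splits" in a test is to replay, by hand, the same sequence of `jax.random.split` calls a stateful module would perform. The sketch below assumes a hypothetical `KeySeq` holder whose `next_key()` keeps one half of each split as state and returns the other; this is an illustration of the technique, not treex's actual API or the fix pushed here:

```python
import jax
import jax.numpy as jnp


class KeySeq:
    """Hypothetical stateful key holder (not treex's actual API).

    Each call to next_key() splits the stored key, keeps the first half
    as the new state and hands out the second half as a fresh subkey.
    """

    def __init__(self, key):
        self.key = key

    def next_key(self):
        self.key, subkey = jax.random.split(self.key)
        return subkey


def replay_keys(key, n):
    # Reproduce the first n subkeys KeySeq would hand out by applying
    # exactly the same split sequence manually.
    subkeys = []
    for _ in range(n):
        key, subkey = jax.random.split(key)
        subkeys.append(subkey)
    return subkeys


seed = jax.random.PRNGKey(42)
seq = KeySeq(seed)
module_keys = [seq.next_key() for _ in range(3)]
expected_keys = replay_keys(seed, 3)
```

The test never has to reach into the module's key state; it only needs to know the split order (and which half is kept), which is exactly what breaks when that order changes upstream.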

I think this is good to go once CI passes.