The current code does not work with newer versions of JAX (and its ecosystem).
Luckily, the main culprit behind the incompatibility (brax) ships a `v1` module that keeps the old API available.
This is not a final solution (a rewrite would probably be better), but at least people can now run this on current versions (and enjoy all the benefits of newer JAX).
Great work BTW.
closes #6, closes #5