google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

[JAX] Fix up uses of PyTree and PyTreeDef types. #615

Closed: copybara-service[bot] closed this 1 year ago

copybara-service[bot] commented 1 year ago

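In many cases, PyTreeDef was used as a type annotation where a pytree (typed as Any) should have been used. A PyTreeDef describes only the structure of a pytree, such as the treedef returned by jax.tree_util.tree_flatten, not the pytree value itself.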
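A minimal sketch of the distinction, assuming a recent JAX release that exports jax.tree_util.PyTreeDef. The PyTree alias and the two helper functions are hypothetical names chosen for illustration; they are not code from this PR or from Haiku. Values such as parameter dicts are annotated with the Any-based alias, while PyTreeDef is reserved for the structure object produced by flattening.

```python
from typing import Any

import jax
import jax.numpy as jnp

# Hypothetical alias: a pytree is any nested container of leaves, which is
# hard to type precisely, so it is annotated as Any.
PyTree = Any


def scale_params(params: PyTree, factor: float) -> PyTree:
    """Multiplies every leaf of a pytree value by `factor`."""
    return jax.tree_util.tree_map(lambda x: x * factor, params)


def structure_of(params: PyTree) -> jax.tree_util.PyTreeDef:
    """Returns only the structure (treedef) of a pytree, with no leaf values."""
    _, treedef = jax.tree_util.tree_flatten(params)
    return treedef


if __name__ == "__main__":
    params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
    # The dict itself is a pytree value, not a PyTreeDef.
    print(isinstance(params, jax.tree_util.PyTreeDef))
    print(scale_params(params, 2.0)["w"])
    print(structure_of(params).num_leaves)
```

Annotating params as PyTreeDef here would be wrong: a dict of arrays is never an instance of PyTreeDef, so the annotation would misdescribe every caller.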