google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Analog to `flax.struct.dataclass` #610

Closed: homerjed closed this issue 1 year ago

homerjed commented 1 year ago

Hi Haiku team,

Is there any potential for something like `flax.struct.dataclass` in Haiku?

I am writing a small library that needs to train a variety of models with different specifications.

For example, a user may need a flow trained with optimizer A, architecture B, etc.

Ideally I would package the whole model architecture and training specification in some kind of `NamedTuple`.

I have tried `chex.dataclass`, `hk.data_structures.to_haiku_dict` and `NamedTuple` variations, but (as far as I know) none of them can mark individual fields as static so that they are ignored by `jax.jit`.
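To make the limitation concrete, here is a minimal sketch (the `TrainSpec` fields are hypothetical). A `NamedTuple` is a pytree in which every field is a child, so non-array settings such as an optimizer name end up as leaves that `jax.jit` tries to trace, and it rejects the string.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class TrainSpec(NamedTuple):
    learning_rate: jax.Array
    optimizer_name: str            # meant to be static, but becomes a leaf
    hidden_sizes: Tuple[int, ...]  # meant to be static, but becomes leaves


spec = TrainSpec(jnp.asarray(0.1), "adam", (64, 64))

# Every field is flattened into leaves: [lr, "adam", 64, 64].
leaves = jax.tree_util.tree_leaves(spec)


@jax.jit
def get_lr(spec):
    return spec.learning_rate


try:
    get_lr(spec)
    jit_error = None
except TypeError as err:
    # jit cannot treat the str leaf as an array, so it raises a TypeError.
    jit_error = err
```

Marking the whole tuple with `static_argnums` avoids the error, but then array fields such as the learning rate cannot live in the same object, which is why per-field static markers are useful.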

Thank you!

tomhennigan commented 1 year ago

Hi @homerjed, I think this is out of scope for Haiku (this library is narrowly focussed on defining a neural network module abstraction and state management). Chex would be the right place in DeepMind's JAX ecosystem for this sort of utility.

Out of curiosity, is there a reason why Flax's dataclass implementation doesn't work for you? It doesn't seem to be coupled to their `nn.Module` implementation, so it should work just fine with Haiku transformed functions.

If you would prefer not to take on an additional dependency, then talking to the chex developers about enhancing their dataclass offering might be a good path forward.

homerjed commented 1 year ago

Hi @tomhennigan, yeah, I like Haiku for its simplicity.

I actually never tried `flax.struct.dataclass`, as I didn't want the additional dependency (it is also non-trivial to extract from their library). I will write to the chex team, good idea.
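For readers who also want to avoid the extra dependency, the core idea can be sketched in plain JAX. The helpers `pytree_dataclass` and `static_field` below are hypothetical (not part of JAX, Haiku or Chex): fields tagged with `metadata={"static": True}` go into the treedef as static data, and the remaining fields become traced children. Recent JAX releases also include `jax.tree_util.register_dataclass` for a similar purpose; check the documentation for your version.

```python
import dataclasses
from typing import Tuple

import jax
import jax.numpy as jnp


def pytree_dataclass(cls):
    """Hypothetical helper: frozen dataclass registered as a pytree."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    data_names = [f.name for f in fields if not f.metadata.get("static", False)]
    meta_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in meta_names)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(data_names, children))
        kwargs.update(zip(meta_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


def static_field(**kwargs):
    """Hypothetical helper: mark a dataclass field as static."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


@pytree_dataclass
class TrainSpec:
    learning_rate: jax.Array
    optimizer_name: str = static_field(default="adam")
    hidden_sizes: Tuple[int, ...] = static_field(default=(64, 64))


@jax.jit
def step_size(spec):
    # Python control flow on static fields is allowed inside jit.
    if spec.optimizer_name == "sgd":
        return spec.learning_rate
    return spec.learning_rate / len(spec.hidden_sizes)


spec = TrainSpec(jnp.asarray(0.5))
result = step_size(spec)
```

Only `learning_rate` is traced; changing `optimizer_name` or `hidden_sizes` produces a new treedef and therefore a retrace.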

For anyone else who comes across this issue: I currently use `simple-pytree`.