homerjed closed this issue 1 year ago
Hi @homerjed, I think this is out of scope for Haiku (this library is narrowly focussed on defining a neural net module abstraction and state management). Chex would be the right place in DeepMind's JAX ecosystem for this sort of utility.
Out of curiosity, is there a reason why Flax's `dataclass` implementation doesn't work for you? It doesn't seem to be coupled to their `nn.Module` implementation, so it should work just fine with Haiku transformed functions.
If you would prefer not to take an additional dep, then talking to the `chex` developers about enhancing their `dataclass` offering might be a good path forward.
Hi @tomhennigan, yeah I like `haiku` for its simplicity.
I actually never tried `flax.struct.dataclass` as I didn't want the additional dependency (it is also non-trivial to extract from their library). I will write to the `chex` team, good idea.
I currently use simple-pytree if anyone else comes by this issue.
Hi Haiku team,
Is there any potential for something like `flax.struct.dataclass` in `haiku`?
I am writing a small library that needs to train a variety of models in different specifications.
E.g. A user may need a flow trained with optimizer A, architecture B etc.
Ideally I would have the whole model architecture and training specification packaged in some kind of `NamedTuple`.
I have tried `chex.dataclass`, `hk.data_structures.to_haiku_dict` and `NamedTuple` variations, but (as far as I know) none of them can mark args as static so that they are ignored by `jax.jit`.
Thank you!
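A dependency-free sketch of what is being asked for here, using only JAX's own pytree registration. `ModelSpec`, its fields, and `sgd_step` are hypothetical names for illustration, not part of Haiku or Chex. The idea is to return array fields as pytree children and everything else as auxiliary data, which `jax.jit` treats as static.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class ModelSpec:  # hypothetical spec container
    params: jax.Array     # traced by jax.jit
    learning_rate: float  # static
    optimizer_name: str   # static

    def tree_flatten(self):
        # Children are traced; aux_data is static and must be hashable.
        children = (self.params,)
        aux_data = (self.learning_rate, self.optimizer_name)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


@jax.jit
def sgd_step(spec: ModelSpec, grads: jax.Array) -> ModelSpec:
    # Python-level check on a static field, resolved at trace time.
    if spec.optimizer_name != "sgd":
        raise ValueError(f"unsupported optimizer: {spec.optimizer_name}")
    new_params = spec.params - spec.learning_rate * grads
    return ModelSpec(new_params, spec.learning_rate, spec.optimizer_name)


spec = ModelSpec(jnp.array([1.0, 2.0]), 0.5, "sgd")
new_spec = sgd_step(spec, jnp.array([2.0, 2.0]))
```

Because the static fields are part of the jit cache key, a new `learning_rate` or `optimizer_name` triggers a retrace, so this suits configuration that changes rarely.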