PhilipVinc opened 1 month ago
I just realised that the non-hashable element is the kwargs dictionary, as hash(dict()) raises TypeError: unhashable type: 'dict'.
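When a class derives its hash from its fields, one dict field is enough to make the whole object unhashable. The minimal sketch below assumes, without checking the Flax source, that Linen modules hash their fields the way a dataclass with unsafe_hash=True does. ModuleLike and is_hashable are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping, Sequence


@dataclass(unsafe_hash=True)
class ModuleLike:
    """Hypothetical stand-in for ToLinen's fields (not the real Flax class).

    unsafe_hash=True generates __hash__ as hash((nnx_class, args, kwargs)),
    so every field value must itself be hashable.
    """

    nnx_class: Callable[..., Any]
    args: Sequence[Any] = ()
    kwargs: Mapping[str, Any] = field(default_factory=dict)


def is_hashable(obj: Any) -> bool:
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(dict()))                     # False: dicts are unhashable
print(is_hashable((32, 64)))                   # True: a tuple of ints is hashable
print(is_hashable(ModuleLike(int, (32, 64))))  # False: the default kwargs is a dict
```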
Using, for example, flax.core.FrozenDict instead works correctly:
>>> from flax.core import FrozenDict
>>> from flax import linen as nn, nnx
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64), kwargs=FrozenDict())
>>> hash(model)
5814856164823000827
This makes sense. Right now one could do something like
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> model.kwargs['hello'] = 1
and potentially break invariants that JAX's caching relies on, since static arguments are expected to be hashable and immutable.
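Why in-place mutation is dangerous shows up with any cache keyed on hash and equality. In this stdlib sketch, MutableKey is a hypothetical key whose hash is computed from a mutable dict, and an ordinary Python dict stands in for a compilation cache (an assumption for illustration, not how JAX actually stores its entries).

```python
from typing import Any, Dict


class MutableKey:
    """Hypothetical cache key whose hash depends on a mutable dict."""

    def __init__(self, kwargs: Dict[str, Any]) -> None:
        self.kwargs = kwargs

    def __hash__(self) -> int:
        # The hash is derived from the *current* contents of the dict.
        return hash(frozenset(self.kwargs.items()))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, MutableKey) and self.kwargs == other.kwargs


cache: Dict[MutableKey, str] = {}  # stands in for a compilation cache
key = MutableKey({})
cache[key] = "compiled for kwargs={}"

key.kwargs["hello"] = 1  # in-place mutation after the key was stored

# Neither an "old" nor a "new" equal key can reach the stored entry any more.
print(MutableKey({}) in cache)            # False: hash matches, but == now fails
print(MutableKey({"hello": 1}) in cache)  # False: == would pass, but the hash differs
print(len(cache))                         # 1: the entry is still there, orphaned
```

After the mutation the stored entry is orphaned: a key equal to the old contents lands on the stored hash but no longer compares equal, and a key equal to the new contents compares equal but has a different hash.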
I think the default here should be some sort of frozen dictionary, and it would also be reasonable to freeze any dictionary passed to the ToLinen module. However, I'm not sure how to achieve the latter.
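One pattern commonly used with frozen dataclasses is to normalise arguments once in __post_init__, bypassing the frozen check with object.__setattr__. The sketch below is plain Python under stated assumptions: FrozenKwargs is a hypothetical stdlib stand-in for flax.core.FrozenDict, ToLinenSketch is a hypothetical class that only mirrors ToLinen's class, args and kwargs arguments, and whether the same idea carries over to a Linen module (which has its own __post_init__ that an override would presumably need to call via super()) has not been checked here.

```python
import collections.abc
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Mapping, Optional, Sequence


class FrozenKwargs(collections.abc.Mapping):
    """Hypothetical immutable, hashable mapping (stand-in for flax.core.FrozenDict)."""

    def __init__(self, data: Optional[Mapping[str, Any]] = None) -> None:
        # Private copy, so later changes to the caller's dict cannot leak in.
        self._data = dict(data or {})

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        # Order-independent; requires every value to be hashable.
        # __eq__ is inherited from Mapping and compares the items.
        return hash(frozenset(self._data.items()))

    def __repr__(self) -> str:
        return f"FrozenKwargs({self._data!r})"


@dataclass(frozen=True)
class ToLinenSketch:
    """Hypothetical frozen dataclass mirroring ToLinen's arguments (not the Flax class)."""

    nnx_class: Callable[..., Any]
    args: Sequence[Any] = ()
    kwargs: Mapping[str, Any] = field(default_factory=FrozenKwargs)

    def __post_init__(self) -> None:
        # frozen=True blocks normal assignment, so object.__setattr__ is used
        # once, at construction time, to swap the inputs for frozen copies.
        object.__setattr__(self, "args", tuple(self.args))
        object.__setattr__(self, "kwargs", FrozenKwargs(self.kwargs))


m1 = ToLinenSketch(int, args=[32, 64], kwargs={"use_bias": True})
m2 = ToLinenSketch(int, args=(32, 64), kwargs={"use_bias": True})
print(m1 == m2, hash(m1) == hash(m2))  # True True

try:
    m1.kwargs["hello"] = 1  # FrozenKwargs has no __setitem__
except TypeError as err:
    print("mutation rejected:", err)
```

The trade-off of this design is that every kwargs value must itself be hashable for hash() on the module to succeed, which matches how a field-based hash treats the other fields.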
Thanks @PhilipVinc for reporting this! I've sent #4159 to try to address this.
Sorry to bump again, but #4159 has not been merged and it would be lovely to see this addressed.
Linen modules are hashable, so I would expect nnx.bridge.ToLinen to be as well. This is problematic because I cannot use ToLinen in place of standard Linen modules, which I pass as static arguments to jax.jit-compiled functions.