google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

ToLinen is not hashable (Linen modules are) #4156

Open PhilipVinc opened 1 month ago

PhilipVinc commented 1 month ago

Linen modules are hashable, so I would expect nnx.bridge.ToLinen to be hashable as well.

In [1]: from flax import linen as nn, nnx

In [2]: import jax

In [3]: model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))

In [4]: hash(model)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:726, in _wrap_hash.<locals>.wrapped(self)
    725 try:
--> 726   hash_value = hash_fn(self)
    727 except TypeError as exc:

File <string>:3, in __hash__(self)

TypeError: unhashable type: 'dict'

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 hash(model)

File ~/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/flax/linen/module.py:728, in _wrap_hash.<locals>.wrapped(self)
    726   hash_value = hash_fn(self)
    727 except TypeError as exc:
--> 728   raise TypeError(
    729     'Failed to hash Flax Module.  '
    730     'The module probably contains unhashable attributes.  '
    731     f'Module={self}'
    732   ) from exc
    733 return hash_value

TypeError: Failed to hash Flax Module.  The module probably contains unhashable attributes.  Module=ToLinen(
    # attributes
    nnx_class = Linear
    args = (32, 64)
    kwargs = {}
    skip_rng = False
)

This is a problem because I cannot use ToLinen in place of the standard Linen modules that I pass as static arguments to jax.jit-compiled functions.
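For context, `jax.jit` requires arguments marked with `static_argnums`/`static_argnames` to be hashable, because they become part of its compilation-cache key. The `File <string>:3, in __hash__(self)` frame in the traceback suggests a generated dataclass `__hash__`, which hashes a tuple of all fields, so a single `dict` field is enough to make it fail. Below is a minimal stdlib-only sketch of that failure. `FakeToLinen` and `try_hash` are hypothetical names that mirror the fields in the repr above. They are not Flax's actual classes.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class FakeToLinen:
    """Hypothetical stand-in with the same fields as the ToLinen repr (not Flax code)."""

    nnx_class: type
    args: tuple = ()
    # A plain dict default: the generated __hash__ fails on this field.
    kwargs: dict = dataclasses.field(default_factory=dict)
    skip_rng: bool = False


def try_hash(obj: Any) -> str:
    """Return 'ok' if obj is hashable, otherwise the TypeError message."""
    try:
        hash(obj)
        return "ok"
    except TypeError as exc:
        return str(exc)


# frozen=True (with the default eq=True) generates __hash__ over all fields.
model = FakeToLinen(nnx_class=object, args=(32, 64))
print(try_hash(model))  # unhashable type: 'dict'
```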

PhilipVinc commented 1 month ago

I just realised that the non-hashable element is the kwargs dictionary, as hash(dict()) fails.

Using, for example, flax.core.FrozenDict works correctly:

>>> from flax.core import FrozenDict
>>> from flax import linen as nn, nnx
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64), kwargs=FrozenDict())
>>> hash(model)
5814856164823000827

This makes sense. Right now, one could do something like

>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> model.kwargs['hello'] = 1

and potentially break some invariants of JAX's caching.
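The mutation concern can be illustrated without JAX. The sketch below uses a plain dict as a simplified stand-in for a compilation cache. `MutableKey` is a hypothetical class whose hash is computed from a mutable dict. This is not how JAX's cache is implemented. It only shows why a cache key that can change after being stored is dangerous. After the mutation, a fresh key with either the old contents or the new contents fails to find the entry.

```python
from typing import Any


class MutableKey:
    """Hypothetical cache key whose hash depends on mutable contents (an anti-pattern)."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = dict(kwargs)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, MutableKey) and self.kwargs == other.kwargs

    def __hash__(self) -> int:
        return hash(frozenset(self.kwargs.items()))


# A plain dict standing in for a cache keyed on static arguments.
key = MutableKey()
cache = {key: "compiled for kwargs={}"}

key.kwargs["hello"] = 1  # mutate the key after it was stored

# Old contents: the hash matches, but the stored key no longer compares equal.
print(MutableKey() in cache)  # False
# New contents: the keys are equal, but the entry was filed under the old hash.
print(MutableKey(hello=1) in cache)  # False
print(len(cache))  # 1: the entry still exists, it is just unreachable
```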

I think the default here should be some sort of frozen dictionary, and it would also be reasonable to freeze any dictionary passed to the ToLinen module. However, I'm not sure how to achieve the latter.
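For a plain frozen dataclass, a common way to freeze a field at construction time is to convert it in `__post_init__` and assign it with `object.__setattr__`, which bypasses the frozen check. Below is a minimal stdlib-only sketch of that pattern. `FrozenKwargs` and `FrozenFakeToLinen` are hypothetical names, and `FrozenKwargs` is only similar in spirit to `flax.core.FrozenDict`. Linen modules have their own dataclass machinery, so applying this inside Flax may look different. Nothing here states what #4159 actually does.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any, Iterator


class FrozenKwargs(Mapping):
    """Hypothetical read-only mapping, hashable when its values are hashable."""

    def __init__(self, data: Any = None) -> None:
        # Copy the input so later edits to the caller's dict cannot leak in.
        self._data = dict(data or {})

    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __hash__(self) -> int:
        # Order-independent hash over the (key, value) pairs.
        return hash(frozenset(self._data.items()))

    def __repr__(self) -> str:
        return f"FrozenKwargs({self._data!r})"


@dataclasses.dataclass(frozen=True)
class FrozenFakeToLinen:
    """Hypothetical stand-in for ToLinen; not Flax's implementation."""

    nnx_class: type
    args: tuple = ()
    kwargs: Mapping = dataclasses.field(default_factory=FrozenKwargs)
    skip_rng: bool = False

    def __post_init__(self) -> None:
        # Freeze whatever mapping was passed in (a plain dict included).
        object.__setattr__(self, "kwargs", FrozenKwargs(self.kwargs))


user_kwargs = {"use_bias": False}
model = FrozenFakeToLinen(object, args=(32, 64), kwargs=user_kwargs)
user_kwargs["hello"] = 1  # does not affect model.kwargs (it holds a copy)
same = FrozenFakeToLinen(object, args=(32, 64), kwargs={"use_bias": False})
print(hash(model) == hash(same))  # True
```

With this design, `model.kwargs['hello'] = 1` raises `TypeError` because `Mapping` provides no `__setitem__`. The catch is that every value in `kwargs` must itself be hashable for `hash(model)` to succeed.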

cgarciae commented 1 month ago

Thanks @PhilipVinc for reporting this! I've sent #4159 to try to address this.

PhilipVinc commented 3 weeks ago

Sorry to bump again, but #4159 has not been merged and it would be lovely to see this addressed.