patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.02k stars, 135 forks

Better bug hint when writing a simple neural network in Equinox #788

Open zhengqigao opened 1 month ago

zhengqigao commented 1 month ago

Hi,

Thanks for the nice package. I am new to Equinox. I tried to write a simple MLP, but it failed with an error, and the error message doesn't make it clear to me how I should revise my code.

import jax
import jax.numpy as jnp
import equinox as eqx

class MLPeqx(eqx.Module):
    def __init__(self, hidden_dims):
        super().__init__()
        tmp_key = jax.random.split(jax.random.PRNGKey(0), len(hidden_dims) - 1)
        self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
        self.activation = jax.nn.relu

    def __call__(self, x):
        for i in range(len(self.layers) - 1):
            x = self.activation(self.layers[i](x))
        x = self.layers[-1](x)
        return x

MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])

The error I got:

Traceback (most recent call last):
  File "xxxxxx/misc/test1.py", line 18, in <module>
    MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 548, in __call__
    self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
  File "xxxxxx/python3.9/site-packages/equinox/_better_abstract.py", line 226, in __call__
    self = super().__call__(*args, **kwargs)
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 376, in __init__
    init(self, *args, **kwargs)
  File "xxxxxx/misc/test1.py", line 9, in __init__
    self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 811, in __setattr__
    raise AttributeError(f"Cannot set attribute {name}")
AttributeError: Cannot set attribute layers

What did I miss?

lockwo commented 1 month ago

Equinox modules are dataclasses (https://docs.python.org/3/library/dataclasses.html), so every attribute has to be declared as an annotated field in the class body (here, layers and activation) before it can be assigned in __init__. See https://docs.kidger.site/equinox/ for an example.