patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

optax.MultiSteps problems, callable set to None #896

Closed louisfabrice13 closed 1 week ago

louisfabrice13 commented 1 week ago

I am having issues with several Equinox custom models and optax.MultiSteps. I noticed problems with MultiSteps and adamw that disappeared with adam, but now the issue seems to be with MultiSteps alone. I have had different error messages that seem to appear out of nowhere, and I cannot trace them to any particular feature. When using actual data and a 1-layer CNN, customized to take a callable as its activation function, the activation function is set to None at the first training iteration:

```python
import equinox as eqx
import jax
import jax.random as jr


class cnn(eqx.Module):
    conv: eqx.nn.Conv3d
    activation: callable

    def __init__(self, activation):
        self.conv = eqx.nn.Conv3d(
            kernel_size=1, in_channels=4, out_channels=3, stride=2, key=jr.PRNGKey(1)
        )
        self.activation = activation

    def __call__(self, x, key=None):
        return self.activation(self.conv(x))


totrain_model = cnn(jax.nn.relu)
```

I initialize the optimizer correctly with
`opt_state = optim.init(eqx.filter(totrain_model, eqx.is_array))`
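(`optim` itself is constructed elsewhere; a hypothetical sketch matching the description, with made-up hyperparameters since the actual values are not shown here, would be:)

```python
import optax

# Hypothetical learning rate and accumulation schedule -- the actual
# values are not shown in this issue.
optim = optax.MultiSteps(optax.adamw(learning_rate=1e-3), every_k_schedule=4)
```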
The training step is
`totrain_model, opt_state, train_loss = train_step(totrain_model, jr.fold_in(epoch_key, batch_idx), opt_state, x, y)`

with `train_step = build_step_updater(optim)`:

```python
def build_step_updater(optimizer: optax.GradientTransformation):
    """Builds a function for executing a single step in the optimization."""

    @eqx.filter_jit
    def make_step_update(model, key, opt_state, x, y):
        print("1")
        print(model)
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, key, x, y)
        print("after loss")
        print(model)
        updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        print("after update of optim")
        print(model)
        model = eqx.apply_updates(eqx.filter(model, eqx.is_array), updates)
        print("after update of model")
        print(model)
        return model, opt_state, loss_value

    return make_step_update
```
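(`loss` is not shown here; a minimal stand-in with the same signature, just to make the snippet self-contained, could be:)

```python
import jax
import jax.numpy as jnp

def loss(model, key, x, y):
    # Placeholder mean-squared error -- the real loss is not shown in the issue.
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)
```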

By printing the model I can confirm that `model = eqx.apply_updates(eqx.filter(model, eqx.is_array), updates)` is where the activation "parameter" changes from `<wrapped function relu>` to `None`.
This makes little sense to me, and it also happens on bigger models.
I may be hitting the dropout issue from [#772](https://github.com/patrick-kidger/equinox/issues/772), but it is not solved by the correct filtering there, and the issue of the activation being set to None remains.
The issue emerged without any change to the environment or kernel I was working in, by the way.
jax.__version__ 0.4.33
eqx.__version__ 0.11.7
optax.__version__ 0.2.3
patrick-kidger commented 1 week ago

Without having run your code, I think you want to replace `model = eqx.apply_updates(eqx.filter(model, eqx.is_array), updates)` with `model = eqx.apply_updates(model, updates)`.

What's happening here is that the model returned on this line does not have any of the non-arrays in its pytree, as you've just filtered them out!
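Concretely, `eqx.filter(model, eqx.is_array)` replaces every leaf that is not an array with `None`, which is exactly where the `None` activation comes from. A minimal sketch, using the `cnn` class from above:

```python
import equinox as eqx
import jax

model = cnn(jax.nn.relu)
filtered = eqx.filter(model, eqx.is_array)

print(filtered.activation)         # None -- the callable is not an array
print(filtered.conv.weight.shape)  # the array leaves survive
```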

louisfabrice13 commented 1 week ago

You are perfectly right, thanks for the timely intervention. I must have modified my code the wrong way while dealing with the Dropout layer, and got stuck with the wrong fix on a different model.
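For reference, the corrected step function, with only the `apply_updates` line changed (and the debug prints dropped):

```python
def build_step_updater(optimizer: optax.GradientTransformation):
    """Builds a function for executing a single step in the optimization."""

    @eqx.filter_jit
    def make_step_update(model, key, opt_state, x, y):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, key, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        # Apply updates to the full model, not a filtered copy, so that
        # non-array leaves (e.g. the activation callable) are preserved.
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    return make_step_update
```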