Closed: louisfabrice13 closed this issue 1 week ago
Without having run your code, I think you want to replace `model = eqx.apply_updates(eqx.filter(model, eqx.is_array), updates)` with `model = eqx.apply_updates(model, updates)`.
What's happening here is that the model returned on this line does not have any of the non-array leaves in the pytree, as you've just filtered them out!
You are perfectly right, thanks for the timely intervention. I must have modified my code the wrong way while dealing with the Dropout layer, and got stuck with the wrong fix on a different model.
I am having issues with several Equinox custom models and optax.MultiSteps. I noticed problems with MultiSteps and adamw that disappeared with adam, but now the issue seems to be with MultiSteps alone. I have had different error messages that seem to spring up out of nowhere; I cannot trace them to any particular feature. When using actual data and a 1-layer CNN, customized to receive a callable as its activation function, the activation function is updated to None at the first training iteration.
```python
def build_step_updater(optimizer: optax.GradientTransformation):
    """Builds a function for executing a single step in the optimization."""
    @eqx.filter_jit
    def make_step_update(model, key, opt_state, x, y):
        print("1")
        print(model)
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, key, x, y)
        print("after loss")
        print(model)
        updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        print("after update of optim")
        print(model)
        model = eqx.apply_updates(eqx.filter(model, eqx.is_array), updates)
        print("after update of model")
        print(model)
        return model, opt_state, loss_value
```