google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Unsupported operand type for float and dict while updating params #808

Closed. ArunDtej closed this issue 8 months ago.

ArunDtej commented 8 months ago

Can someone help me make this work, or tell me what's causing the error?

While updating the params, I get: TypeError: unsupported operand type(s) for *: 'float' and 'dict'. Here is the code:

```python
import jax
import jax.numpy as jnp
import optax

# NOTE: x is used below but never defined in this snippet.
def f(x):
    return 5 * x - 10

y_true = f(x)

params = {'weights': jnp.ones((3, 3))}

def model(x):
    return x.dot(params['weights'])

def loss(y_true, x):
    y_pred = model(x)
    return jnp.mean((y_true - y_pred) ** 2)

print(loss(y_true, model(x)))

optimizer = optax.adam(learning_rate=0.002)
opt_state = optimizer.init(params)
loss_value, grads = jax.value_and_grad(loss, allow_int=True)(y_true, x)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```

vroulet commented 8 months ago

Hello @ArunDtej, you applied value_and_grad to y_true, not to params: by default, jax.value_and_grad differentiates a function with respect to its first positional argument. The gradients therefore have the shape of y_true rather than the tree structure of params, so the resulting updates do not match params either, and the update step fails with the error above.
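For illustration, here is a minimal corrected sketch. It passes params explicitly to the model and loss so that value_and_grad (applied to the first argument, params) returns gradients with the same dict structure as params. The input x is a placeholder, since the original snippet does not define it:

```python
import jax
import jax.numpy as jnp
import optax

x = jnp.ones((4, 3))   # placeholder input; not defined in the original snippet
y_true = 5 * x - 10    # targets from the original f

params = {'weights': jnp.ones((3, 3))}

def model(params, x):
    # Take params as an argument so JAX can differentiate with respect to it.
    return x.dot(params['weights'])

def loss(params, x, y_true):
    y_pred = model(params, x)
    return jnp.mean((y_true - y_pred) ** 2)

optimizer = optax.adam(learning_rate=0.002)
opt_state = optimizer.init(params)

# Differentiate w.r.t. the first argument (params), so grads is a dict
# with the same structure as params.
loss_value, grads = jax.value_and_grad(loss)(params, x, y_true)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```

Passing params explicitly, rather than closing over a global dict, is the standard JAX pattern: it keeps the loss a pure function of its inputs, which is what value_and_grad and jit expect.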

ArunDtej commented 8 months ago

Thank you @vroulet, I should have referred to the documentation. Sorry for the trouble :')