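Hello @ArunDtej, you applied `value_and_grad` to `y_true`, not to `params`. By default it differentiates with respect to the first positional argument of `loss`, which here is `y_true` (and `params` is not an argument of `loss` at all; it is read as a global). So the gradients do not have the structure of `params`, hence the updates do not either, and that causes the error.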
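A minimal sketch of what the reply describes, assuming a hypothetical input `x` (the question does not show how it was built): once `params` is the first argument of `loss`, `jax.value_and_grad` differentiates with respect to it and returns gradients with the same dict structure. Selecting `y_true` instead (via `argnums`) returns a plain array shaped like `y_true`, which cannot serve as an update for the `params` dict.

```python
import jax
import jax.numpy as jnp

# Hypothetical data standing in for the x used in the question.
x = jnp.arange(12.0).reshape(4, 3)
y_true = 5 * x - 10
params = {"weights": jnp.ones((3, 3))}

def loss(params, x, y_true):
    # params is an explicit argument, so JAX can trace and differentiate it.
    y_pred = x @ params["weights"]
    return jnp.mean((y_true - y_pred) ** 2)

# Default argnums=0: gradients w.r.t. params, with the same pytree structure.
loss_value, grads = jax.value_and_grad(loss)(params, x, y_true)
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
print(grads["weights"].shape)  # (3, 3)

# argnums=2: gradients w.r.t. y_true, a plain array shaped like y_true.
_, grads_y = jax.value_and_grad(loss, argnums=2)(params, x, y_true)
print(grads_y.shape)  # (4, 3)
```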
Thank you @vroulet, I should have checked the documentation. Sorry for the trouble :')
Can someone help me make this work, or tell me what's causing the error?
I get `TypeError: unsupported operand type(s) for *: 'float' and 'dict'` when updating the params:
```python
import jax
import jax.numpy as jnp
import numpy as np
import optax

# x: input array with 3 columns (its definition is not shown in this post)

def f(x):
    return 5*x - 10

y_true = f(x)
params = {
    'weights': jnp.ones((3, 3))
}

def model(x):
    return x.dot(params['weights'])

def loss(y_true, x):
    y_pred = model(x)
    return np.mean((y_true - y_pred)**2)

print(loss(y_true, model(x)))

optimizer = optax.adam(learning_rate=0.002)
opt_state = optimizer.init(params)
loss_value, grads = jax.value_and_grad(loss, allow_int=True)(y_true, x)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```