Open cyugao opened 1 year ago
I was trying to run a simple example, but I get a `TypeError` when evaluating the gradients:
TypeError: Argument 'objax.TrainVar(Traced<ConcreteArray([-1.1010288 -0.6818452 -0.95236534], dtype=float32)>with<JVPTrace(level=2/0)> with primal = Array([-1.1010288 , -0.6818452 , -0.95236534], dtype=float32) tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with pval = (ShapedArray(float32[3]), None) recipe = LambdaBinding(), reduce=reduce_mean)' of type <class 'objax.variable.TrainVar'> is not a valid JAX type.
Minimal example from the docs:
    import objax
    import jax.numpy as jn

    n = 1000
    ndim = 10
    X = objax.random.normal((n, ndim))
    y = objax.random.normal((n, 1))
    w = objax.TrainVar(jn.zeros(ndim))
    b = objax.TrainVar(jn.zeros(1))

    def loss(x, y):
        pred = jn.dot(x, w) + b
        return 0.5 * ((y - pred) ** 2).mean()

    g_fn = objax.Grad(loss,  # g_fn is Objax module
                      objax.VarCollection({'w': w, 'b': b}))
    g_value = g_fn(X, y)
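For context, here is a minimal sketch in plain JAX (not Objax) of the kind of failure the error message describes: a wrapper object is rejected where JAX expects an array, while the array it holds is accepted. `Param` is a hypothetical stand-in for a variable wrapper such as `objax.TrainVar`, and the data is made up. Whether unwrapping the variables (for example via `w.value` and `b.value` inside `loss`) resolves the error for the Objax and JAX versions in this report is an assumption and has not been verified here.

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical stand-in for a variable wrapper such as objax.TrainVar."""

    def __init__(self, value):
        self.value = value


def loss(w, x, y):
    # w must be a raw array (or tracer) for JAX to differentiate through it.
    pred = jnp.dot(x, w)
    return 0.5 * ((y - pred) ** 2).mean()


x = jnp.ones((4, 3))
y = jnp.ones(4)
p = Param(jnp.zeros(3))

# Passing the wrapper itself is rejected: it is not a valid JAX type.
try:
    jax.grad(loss)(p, x, y)
    wrapper_rejected = False
except TypeError:
    wrapper_rejected = True

# Passing the underlying array works.
g = jax.grad(loss)(p.value, x, y)
```

The sketch only isolates the type check. It does not reproduce how `objax.Grad` substitutes traced values into the variables in the `VarCollection`.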