google / objax

Apache License 2.0

TypeError during gradient computation: type <class 'objax.variable.TrainVar'> is not a valid JAX type #260

Open cyugao opened 1 year ago

cyugao commented 1 year ago

I was trying to run a simple example, but evaluating the gradients fails with a type error:

```
TypeError: Argument 'objax.TrainVar(Traced<ConcreteArray([-1.1010288  -0.6818452  -0.95236534], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([-1.1010288 , -0.6818452 , -0.95236534], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = LambdaBinding(), reduce=reduce_mean)' of type <class 'objax.variable.TrainVar'> is not a valid JAX type.
```

Minimal example from the docs:

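```python
import objax
import jax.numpy as jn

n = 1000
ndim = 10
X = objax.random.normal((n, ndim))
y = objax.random.normal((n, 1))
w = objax.TrainVar(jn.zeros(ndim))  # trainable weights, shape (ndim,)
b = objax.TrainVar(jn.zeros(1))     # trainable bias, shape (1,)

def loss(x, y):
    # w and b are TrainVar objects, not arrays; they are handed to jn.dot and + directly.
    # Note: pred has shape (n,) while y has shape (n, 1), so y - pred broadcasts to (n, n).
    pred = jn.dot(x, w) + b
    return 0.5 * ((y - pred) ** 2).mean()

g_fn = objax.Grad(loss,           # g_fn is Objax module
                  objax.VarCollection({'w': w, 'b': b}))
g_value = g_fn(X, y)  # raises the TypeError above
```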
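As a cross-check that does not depend on objax, here is a minimal sketch of the same least-squares gradient in plain JAX. The parameters are passed to `jax.grad` as ordinary arrays in a dict rather than as `objax.TrainVar` objects; the PRNG seed and the dict layout are choices made only for this sketch. It also gives `y` shape `(n,)`, because with `y` of shape `(n, 1)` as above, `y - pred` would silently broadcast to `(n, n)`. Within objax itself, the usual pattern is to read the underlying arrays explicitly inside `loss`, i.e. `jn.dot(x, w.value) + b.value`; that this avoids the error with the versions in this report is an assumption, not something confirmed in this thread.

```python
import jax
import jax.numpy as jnp

# Same sizes as the report: n samples, ndim features (seed 0 is arbitrary).
n, ndim = 1000, 10
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (n, ndim))
# Targets as a flat (n,) vector so that y - pred stays shape (n,).
y = jax.random.normal(key_y, (n,))

# Assumption for this sketch: parameters live in a plain dict (a pytree of arrays),
# standing in for the objax TrainVars w and b.
params = {'w': jnp.zeros(ndim), 'b': jnp.zeros(1)}

def loss(params, x, y):
    pred = jnp.dot(x, params['w']) + params['b']  # shape (n,)
    return 0.5 * ((y - pred) ** 2).mean()

# jax.grad differentiates w.r.t. the first argument and returns
# gradients with the same dict structure as params.
grads = jax.grad(loss)(params, X, y)
print(grads['w'].shape, grads['b'].shape)  # (10,) (1,)
```

At the zero initialisation the residual is just `y`, so the gradients reduce to `-X.T @ y / n` for `w` and `-mean(y)` for `b`, which gives a quick sanity check on whatever objax returns once the type error is resolved.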