google / trajax

Apache License 2.0
186 stars 23 forks source link

Does not work with Brax #6

Open MahanFathi opened 1 year ago

MahanFathi commented 1 year ago

Has anyone tried the solvers on BRAX environments? Here's what I have:

import trajax
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
import brax
from brax import envs

def get_f_and_c(env):
    """Build dynamics and cost callables that work on flattened states.

    The environment is reset once with ``jax.random.PRNGKey(0)``; the
    resulting state serves as a template for rebuilding full states, and
    its ``qp`` pytree defines the layout of the flat state vector.

    Args:
        env: Object exposing ``reset(key)``, ``step(state, action)`` and
            ``sys.step(qp, action)``; the state returned by ``reset`` must
            have a ``qp`` attribute and a ``replace(qp=...)`` method.

    Returns:
        A tuple ``(f, c)`` of callables, each taking ``(x, u, t)``.
    """
    template_state = env.reset(jax.random.PRNGKey(0))
    # Only the unflatten function is needed; the flat vector is discarded.
    unflatten = ravel_pytree(template_state.qp)[1]

    def f(x, u, t):
        """Return the flattened ``qp`` after ``env.sys.step``; ``t`` is unused."""
        next_qp, _ = env.sys.step(unflatten(x), u)
        flat_next, _ = ravel_pytree(next_qp)
        return flat_next

    def c(x, u, t):
        """Return the negated reward of one ``env.step``; ``t`` is unused."""
        # Swap the candidate qp into the template state before stepping.
        stepped = env.step(template_state.replace(qp=unflatten(x)), u)
        return -stepped.reward

    return f, c

# Create the Brax environment and take its initial state.
env = envs.create('inverted_pendulum')
key = jax.random.PRNGKey(0)
state = env.reset(key)
# Flatten the initial `state.qp` pytree into a 1-D vector; `x2qp` is the
# matching unflatten function returned by `ravel_pytree`.
x_init, x2qp = ravel_pytree(state.qp)

# Dynamics `f` and cost `c`, both operating on flattened state vectors.
f, c = get_f_and_c(env)

# Run iLQR from `x_init` with an all-zero initial control array of shape
# (1, env.action_size).
# NOTE(review): presumably the leading dimension is the planning horizon,
# i.e. a single step here -- confirm against the trajax documentation.
x, u, cost, *outputs = trajax.optimizers.ilqr(c, f, x_init, jnp.zeros([1, env.action_size]))

Running this script gives the following error:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError