Has anyone tried the trajax solvers (e.g. iLQR) on Brax environments? Here's what I have:
import trajax
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
import brax
from brax import envs
def get_f_and_c(env):
    """Build flat-state dynamics and cost callables for ``env``.

    The environment is reset once with ``jax.random.PRNGKey(0)``. The ``qp``
    pytree of that reset state supplies the unravel function used to turn a
    flat vector back into a ``qp`` structure.

    Args:
        env: Object exposing ``reset(key)``, ``step(state, action)`` and
            ``sys.step(qp, action)``, as called below.

    Returns:
        A pair ``(f, c)``. Both take ``(x, u, t)``, where ``x`` is the
        flattened ``qp`` and ``u`` is the action; ``t`` is accepted but unused.
    """
    key = jax.random.PRNGKey(0)
    template_state = env.reset(key)
    # Only the unravel function is needed here; the flat vector is discarded.
    _, x2qp = ravel_pytree(template_state.qp)

    def f(x, u, t):
        """Return the flattened ``qp`` produced by one ``env.sys.step``."""
        del t  # Unused.
        next_qp, _ = env.sys.step(x2qp(x), u)
        flat_next, _ = ravel_pytree(next_qp)
        return flat_next

    def c(x, u, t):
        """Return the negated reward of one ``env.step`` taken from ``x``."""
        del t  # Unused.
        # Only ``qp`` is overridden; the rest of the state passed to
        # ``env.step`` is whatever ``env.reset`` produced above.
        start_state = template_state.replace(qp=x2qp(x))
        return -env.step(start_state, u).reward

    return f, c
# Create the environment and flatten the initial physics state into a vector.
env = envs.create('inverted_pendulum')
key = jax.random.PRNGKey(0)
state = env.reset(key)
x_init, x2qp = ravel_pytree(state.qp)

# Dynamics and cost callables operating on the flattened state.
f, c = get_f_and_c(env)

# Initial control sequence: a single all-zero action.
u_init = jnp.zeros([1, env.action_size])
x, u, cost, *outputs = trajax.optimizers.ilqr(c, f, x_init, u_init)
Running this raises the following error:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Has anyone tried the solvers on BRAX environments? Here's what I have:
which gives: