Open Jaldrich2426 opened 2 days ago
Hi @Jaldrich2426 - which pipeline are you using? Yes, I would expect the jacobian wrt state to give meaningful results. Can you post a colab or code snippet?
Thanks @erikfrey! I'm working from a trimmed-down version of this notebook. I've primarily been testing with the positional pipeline, but I tried the others with no luck. I've also moved to using the observation instead of the pipeline state, since in this example it contains the pipeline state's position and velocity (q and qd). Here's my trimmed version:
import functools
import jax
from jax import numpy as jp
from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo
env_name = 'ant'
backend = 'positional' # @param ['generalized', 'positional', 'spring']
env = envs.get_environment(env_name=env_name,
                           backend=backend)
state = env.reset(rng=jax.random.PRNGKey(seed=0))
train_fn = functools.partial(
    ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
    batch_size=2048, seed=1)
make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)
env = envs.create(env_name=env_name, backend=backend)
def trimmed_state_step(state, action):
    # Step the environment and return only the new observation.
    new_state = env.step(state, action)
    return new_state.obs
rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)
act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)
new_q = trimmed_state_step(state, act)
print(new_q)
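# Jacobian of the new observation with respect to the action.
print(jax.jacobian(trimmed_state_step, argnums=1)(state, act))
# The Jacobian with respect to the input state is a state-shaped pytree; .obs selects the block for the input observation.
print(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs)
# Norm of that block, as a quick check for an all-zero result.
print(jp.linalg.norm(jax.jacobian(trimmed_state_step, argnums=0)(state, act).obs))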
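One possible explanation for the zeros, not confirmed in this thread: if the step function never reads the incoming observation and instead recomputes it from the underlying physics state, then the Jacobian block for the input obs is zero by construction. The sketch below shows that effect on a toy system in plain JAX. It does not use Brax; ToyState, toy_step, and obs_after_step are hypothetical names made up for illustration.

```python
from typing import NamedTuple

import jax
from jax import numpy as jp


class ToyState(NamedTuple):
    """Hypothetical stand-in for an environment state (not a Brax type)."""
    q: jax.Array    # the quantity the toy dynamics actually read
    obs: jax.Array  # a cached observation derived from q


def toy_step(state: ToyState, action: jax.Array) -> ToyState:
    # The incoming state.obs is never read: obs is recomputed from the new q.
    new_q = state.q + 0.1 * jp.sin(state.q) + 0.5 * action
    return ToyState(q=new_q, obs=2.0 * new_q)


def obs_after_step(state: ToyState, action: jax.Array) -> jax.Array:
    return toy_step(state, action).obs


state = ToyState(q=jp.array([0.3, -1.2]), obs=jp.array([0.6, -2.4]))
action = jp.array([0.1, 0.2])

# Differentiating with respect to a pytree returns a pytree of Jacobians,
# one block per input leaf.
jac = jax.jacobian(obs_after_step, argnums=0)(state, action)
print(jac.obs)  # all zeros: the input obs does not influence the output
print(jac.q)    # nonzero: the output depends on the input q
```

If the same holds for env.step, the block to inspect would be one of the fields inside the pipeline state rather than obs. Which fields each backend actually reads during a step is an open question here, so the other blocks of the returned pytree are worth printing as well.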
Hi, in a similar direction to #83, I'm attempting to use Brax for a form of optimal control, but I need access to the Jacobian of the dynamics step function with respect to both the input state and the action. Specifically, for a step that maps state and action to new_state, I'd like the Jacobian of new_state with respect to certain elements in state, and the Jacobian of new_state with respect to action. If this were a linear model, I'd want A and B. For the ant sample environment, I was able to compute a gradient with respect to action by defining a helper step function and differentiating it. However, attempting to achieve similar results for relevant fields of the state (q in this example) yields an all-zeros result. Is there a better way to compute these Jacobians, and is this supported behavior in Brax?
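As a sanity check on the recipe itself, here is a minimal sketch assuming the linear model meant above has the usual form x_next = A x + B u. On such a system, jax.jacobian with respect to the state and the action should return A and B exactly. The matrices and the linear_step function are hypothetical and are not taken from Brax.

```python
import jax
from jax import numpy as jp

# Hypothetical linear system, used only to check the recipe.
A = jp.array([[1.0, 0.1],
              [0.0, 1.0]])
B = jp.array([[0.0],
              [0.1]])


def linear_step(x: jax.Array, u: jax.Array) -> jax.Array:
    # x_next = A x + B u
    return A @ x + B @ u


x0 = jp.array([0.5, -0.2])
u0 = jp.array([0.3])

# argnums=(0, 1) returns the Jacobians for both inputs from a single call.
A_hat, B_hat = jax.jacobian(linear_step, argnums=(0, 1))(x0, u0)
print(A_hat)  # Jacobian with respect to the state, shape (2, 2)
print(B_hat)  # Jacobian with respect to the action, shape (2, 1)
```

For a nonlinear step the same call gives the local linearization around (x0, u0), provided the function being differentiated takes as input the quantities the dynamics actually depend on.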