google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Jacobian of State Dynamics #496

Open Jaldrich2426 opened 2 days ago

Jaldrich2426 commented 2 days ago

Hi, in a similar direction to #83, I'm attempting to use Brax for a form of optimal control, but I need access to the Jacobian of the dynamics step function with respect to both the input state and the action. Specifically, if

f(state,action) -> new_state

I'd like the Jacobian of new_state with respect to certain elements of state, and the Jacobian of new_state with respect to action. If this were a linear model,

x_next = Ax+Bu
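For a nonlinear step function, the A and B of that linear model correspond to the Jacobians of f evaluated at a particular state and action. Written as the standard first-order linearization around a nominal state x̄ and action ū (notation introduced here only for illustration):

```latex
x_{t+1} \approx f(\bar{x}, \bar{u}) + A\,(x_t - \bar{x}) + B\,(u_t - \bar{u}),
\qquad
A = \left.\frac{\partial f}{\partial x}\right|_{(\bar{x},\,\bar{u})},
\qquad
B = \left.\frac{\partial f}{\partial u}\right|_{(\bar{x},\,\bar{u})}
```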

I'd want A and B. For the ant sample environment, I was able to compute the Jacobian with respect to the action by defining a helper step function:

def trimmed_state_step(state, action):
    new_state = env.step(state, action)
    return new_state.pipeline_state.q

and calling

jax.jacobian(trimmed_state_step, argnums=1)(state, act)

However, doing the same for the relevant fields of the state (q in this example),

jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state.q

yields all zeros. Is there a better way to compute these Jacobians, and is this behavior supported by Brax?
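As a sanity check on the A/B framing, a toy linear system where the answer is known shows that passing argnums=(0, 1) to jax.jacobian returns both Jacobians in one call and recovers A and B exactly. This is a minimal sketch: toy_step, A, B, x and u are hypothetical stand-ins for env.step and its inputs, and Brax is not used.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy dynamics: a linear system x_next = A x + B u.
A = jnp.array([[1.0, 0.1],
               [0.0, 1.0]])
B = jnp.array([[0.0],
               [0.1]])


def toy_step(x, u):
    """Toy stand-in for env.step: returns the next state vector."""
    return A @ x + B @ u


x = jnp.array([0.5, -0.2])
u = jnp.array([0.3])

# argnums=(0, 1) differentiates with respect to both x and u and returns
# a tuple (d x_next / d x, d x_next / d u).
jac_x, jac_u = jax.jacobian(toy_step, argnums=(0, 1))(x, u)

print(jac_x)  # equals A, shape (2, 2)
print(jac_u)  # equals B, shape (2, 1)
```

For a nonlinear step the same call gives the linearization at the particular (x, u) it is evaluated at, so the resulting matrices change from state to state.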

erikfrey commented 1 day ago

Hi @Jaldrich2426, which pipeline are you using? Yes, I would expect the Jacobian with respect to the state to give meaningful results. Can you post a colab or code snippet?

Jaldrich2426 commented 1 day ago

Thanks @erikfrey! I'm working off a trimmed-down version of this notebook here. I've primarily been testing with the positional pipeline, but I also tried the others with no luck. I've also switched to using the observation instead of the pipeline state, since in this example the observation contains the pipeline state's positions and velocities (q and qd). Here's my trimmed version:

import functools
import jax
from jax import numpy as jp

from brax import envs
from import model
from import train as ppo

env_name = 'ant'
backend = 'positional'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name, backend=backend)

state = env.reset(rng=jax.random.PRNGKey(seed=0))

train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000,
    num_evals=10,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=4096,
    batch_size=2048,
    seed=1)

make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
    new_state = env.step(state, action)
    return new_state.obs

rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)

act_rng, rng = jax.random.split(rng)
act, _ = inference_fn(state.obs, act_rng)

new_obs = trimmed_state_step(state, act)

# d(new_obs)/d(action): an array of shape (obs_size, action_size)
print(jax.jacobian(trimmed_state_step, argnums=1)(state, act))

# d(new_obs)/d(state): a pytree with the same structure as `state`
jac_state = jax.jacobian(trimmed_state_step, argnums=0)(state, act)
print(jac_state.obs)
print(jp.linalg.norm(jac_state.obs))
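One general JAX behavior may be relevant to the zeros: differentiating with respect to a whole pytree returns one Jacobian block per field, and any field the step function never reads gets an all-zero block. A minimal sketch of this, plus a way to get a flat A matrix for selected fields by closing over the rest of the state, is below. ToyState, toy_step, step_obs, step_from_q_qd and DT are all hypothetical; this is not Brax's State class or step function.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical toy state; NamedTuples are pytrees in JAX."""
    q: jnp.ndarray    # positions: read by the step function
    qd: jnp.ndarray   # velocities: read by the step function
    obs: jnp.ndarray  # observation: written by the step, never read


DT = 0.1  # hypothetical time step


def toy_step(state: ToyState, action: jnp.ndarray) -> ToyState:
    """Semi-implicit Euler step; the new obs is recomputed from q and qd."""
    qd = state.qd + DT * jnp.sin(state.q) + DT * action
    q = state.q + DT * qd
    return ToyState(q=q, qd=qd, obs=jnp.concatenate([q, qd]))


def step_obs(state: ToyState, action: jnp.ndarray) -> jnp.ndarray:
    return toy_step(state, action).obs


state = ToyState(q=jnp.array([0.2, -0.1]), qd=jnp.zeros(2), obs=jnp.zeros(4))
action = jnp.array([0.5, 0.0])

# The result is a ToyState whose leaves have shape new_obs.shape + leaf.shape.
jac_state = jax.jacobian(step_obs, argnums=0)(state, action)
print(jac_state.obs)      # all zeros: the input obs never affects the output
print(jac_state.q.shape)  # (4, 2): the nonzero block d new_obs / d q


# For a flat A matrix over selected fields, close over the rest of the state
# and differentiate with respect to a plain array instead of the pytree.
def step_from_q_qd(q_qd: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
    q, qd = jnp.split(q_qd, 2)
    new_state = toy_step(state._replace(q=q, qd=qd), action)
    return jnp.concatenate([new_state.q, new_state.qd])


A, B = jax.jacobian(step_from_q_qd, argnums=(0, 1))(
    jnp.concatenate([state.q, state.qd]), action)
print(A.shape, B.shape)  # (4, 4) (4, 2)
```

Which fields of Brax's own State a given pipeline actually reads during a step (for example obs versus the fields inside pipeline_state) is not shown here; the toy only illustrates that an unread input field yields a zero Jacobian block.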