google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Updated basic APG algorithm #476

Closed · Andrew-Luo1 closed this 2 months ago

Andrew-Luo1 commented 2 months ago

The goal of this proposed update is to provide a simple APG algorithm that can solve non-trivial tasks, as a first step for researchers and practitioners to explore Brax's differentiable simulation. It has been tested on MJX. Notes:

1: Algorithm Update

This fork contains an APG implementation that is about as simple as the current one, but reflects a common thread among several recent results that use differentiable simulators to achieve locomotion: 1, 2, 3.

Brax's current APG algorithm is roughly equivalent to the following pseudocode:

for i in range(n_epochs):
    state = env.reset()
    policy_grads = []
    for j in range(episode_length // short_horizon):
        # Differentiate the return of a short unroll w.r.t. the policy; carry the state forward.
        policy_grad, state = jax.grad(unroll, has_aux=True)(policy, state, short_horizon)
        policy_grads.append(policy_grad)
    optimizer.step(mean(policy_grads))

In contrast, the cited results update the policy much more frequently, exploiting the observation that policy gradients computed by differentiating through the simulator have low variance. Unrolling an entire episode before each update is therefore of limited use; the fact that additional samples stop helping past a certain point is reflected in convergence not improving with massive parallelization [2]. The proposed APG algorithm essentially performs live stochastic gradient descent on the policy: it unrolls for a short window, takes a gradient step, then continues from where the unroll left off:

state = env.reset()
for i in range(n_epochs):
    policy_grad, state = jax.grad(unroll, has_aux=True)(policy, state, short_horizon)
    optimizer.step(policy_grad)

Note that n_epochs can be much larger in this case. This modification allows the algorithm to relatively quickly learn quadruped locomotion, albeit with a particular training pipeline and reward design. (Notebook)
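For concreteness, here is a minimal, self-contained sketch of the short-horizon update scheme described above. It uses a toy differentiable environment (a double integrator) in place of Brax/MJX; the names env_reset, env_step, unroll_loss and the hyperparameters are illustrative stand-ins, not the code in this PR.

import jax
import jax.numpy as jnp
import optax


def env_reset(rng):
  # Toy stand-in for a differentiable Brax/MJX reset: state = [position, velocity].
  pos = jax.random.uniform(rng, (), minval=-1.0, maxval=1.0)
  return jnp.array([pos, 0.0])


def env_step(state, action):
  # Toy differentiable dynamics; the reward drives the position toward zero.
  pos, vel = state
  vel = vel + 0.1 * action
  pos = pos + 0.1 * vel
  reward = -(pos ** 2) - 0.01 * action ** 2
  return jnp.array([pos, vel]), reward


def policy(params, state):
  return jnp.tanh(params['w'] @ state + params['b'])


def unroll_loss(params, state, horizon):
  # Unroll the policy for a short window; return (negative return, final state).
  def body(carry, _):
    s, total = carry
    s, r = env_step(s, policy(params, s))
    return (s, total + r), None

  (state, total_reward), _ = jax.lax.scan(
      body, (state, jnp.zeros(())), None, length=horizon)
  return -total_reward, state


horizon, n_updates = 10, 500
params = {'w': jnp.zeros(2), 'b': jnp.zeros(())}
opt = optax.adam(1e-2)
opt_state = opt.init(params)

state = env_reset(jax.random.PRNGKey(0))
grad_fn = jax.jit(jax.value_and_grad(unroll_loss, has_aux=True), static_argnums=2)

for _ in range(n_updates):
  # Short unroll, immediate gradient step, then continue from the final state.
  (_, state), grads = grad_fn(params, state, horizon)
  updates, opt_state = opt.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

The key structural point is the final loop: the optimizer steps after every short unroll, and the simulation state is carried across updates instead of being reset each epoch.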

Additional notes:

2: Supporting Updates

Configurable initial policy variance: When hot-starting from an existing policy, it helps to explore close to the region of state space that policy induces. One way to achieve this is to initialize the policy network weights to small values, but the softplus used to parameterize the action standard deviation currently prevents the resulting variance from being small, so this fork adds a scaling parameter (sketched after these notes).

Layer norm: I have found that using layer normalization in the policy network greatly improves the training stability of APG, and it appears in other implementations as well (also sketched below).
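As a hedged illustration of the variance point (not this PR's actual diff): Brax's normal-tanh action distribution parameterizes the standard deviation roughly as softplus(logits) + min_std, so zero-initialized scale outputs still produce a sizable std. A multiplicative factor, here the hypothetical var_scale, lets a hot-started policy begin with tight exploration:

import jax
import jax.numpy as jnp


def action_std(scale_logits, min_std=0.001, var_scale=1.0):
  # softplus(0) ~= 0.69, so small or zero-initialized scale weights cannot yield
  # a small standard deviation unless the softplus output is scaled down.
  return var_scale * (jax.nn.softplus(scale_logits) + min_std)


print(action_std(jnp.zeros(())))                 # ~0.69: default exploration noise
print(action_std(jnp.zeros(()), var_scale=0.1))  # ~0.07: stays near the hot-started policy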
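And a sketch of the layer-norm change, written as a standalone Flax module rather than the actual network factory in brax/training/networks.py; the layer sizes, activation, and observation/action dimensions below are placeholders:

import flax.linen as nn
import jax
import jax.numpy as jnp


class PolicyMLP(nn.Module):
  """Policy MLP with layer normalization after each hidden dense layer."""
  hidden_sizes: tuple = (64, 64)
  action_size: int = 12

  @nn.compact
  def __call__(self, obs):
    x = obs
    for size in self.hidden_sizes:
      x = nn.Dense(size)(x)
      x = nn.LayerNorm()(x)  # normalize features before the nonlinearity
      x = nn.swish(x)
    return nn.Dense(self.action_size)(x)


params = PolicyMLP().init(jax.random.PRNGKey(0), jnp.zeros((1, 24)))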

google-cla[bot] commented 2 months ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

erikfrey commented 2 months ago

Oh - it looks like you need to update brax/training/agents/apg/train_test.py - changes should hopefully be minimal - please review the failing test.

Also, please do sign the CLA. Thank you!

Andrew-Luo1 commented 2 months ago

Hi @erikfrey, I have updated the tests and they pass on my local setup. I've also fixed the nit. I've signed the CLA, and the Checks tab is saying that my signing went through. Please let me know if there's anything missing.

Andrew-Luo1 commented 2 months ago

I removed the double precision toggle and the tests still run fine on my local setup. Let's see if this works :)

erikfrey commented 2 months ago

Amazing, thank you!