google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks source link

Scaling control actions for BRAX environments #473

Closed nic-barbara closed 1 month ago

nic-barbara commented 3 months ago

This PR is an attempt to fix https://github.com/google/brax/issues/472. With this change, all environments assume the control action is restricted to [-1, 1].

For the few environments that have an action space that is not [-1,1] (humanoid, humanoidstandup, pusher, inverted_pendulum), the input action is linearly scaled to the limits of the action space inside the environment's step() function.

Please let me know whether this is an appropriate solution; I'd be happy to iterate on it. Note also that scaling the actions means the "best" hyperparameters for these environments will likely change. Does this need to be addressed anywhere?

google-cla[bot] commented 3 months ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

nic-barbara commented 1 month ago

@btaba just a heads-up: I haven't forgotten about this and will run ablation studies soon! All my GPUs are currently in use for other research.

nic-barbara commented 1 month ago

@btaba I've just run ablation studies on the humanoid and inverted_pendulum environments.

Results are attached. Action scaling doesn't seem to harm performance in either case. Attachments: humanoid_results_scaled_v2.pdf, humanoid_results_v2.pdf, inverted_pendulum_results_scaled_v2.pdf, inverted_pendulum_results_v2.pdf

Would you like me to run them for humanoidstandup and pusher too?

For reference, here's the code I used to generate these results. It's essentially the same as the Brax training notebook, with some additional data logging and plotting.

import functools
import matplotlib.pyplot as plt
import numpy as np

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo
from datetime import datetime
from pathlib import Path

# Filename suffix for this batch of runs: "_scaled" for runs with action
# scaling, "" for the unscaled baseline. It is embedded in every results file
# name so the two sets of runs do not overwrite each other.
suffix = "_scaled"

# Environments to train; other supported choices are "inverted_pendulum",
# "pusher" and "humanoidstandup".
env_names = ["humanoid"]

# Random seeds: each environment is trained once per seed.
seeds = [0, 1, 2]

def _get_fname(env_name, seed):
    """Return the results path (without extension) for one environment/seed run.

    The file sits next to this script. Its name embeds the module-level
    ``suffix``, so scaled and unscaled runs are kept apart.
    """
    script_dir = Path(__file__).resolve().parent
    return script_dir / f"{env_name}_results{suffix}_v{seed}"

def train_model(env_name, seed, backend='positional'):
    """Train a PPO policy on one environment and save parameters plus learning curve.

    Hyperparameters follow the Brax training tutorial. Every evaluation
    reported through the PPO progress callback is printed and recorded. When
    training finishes, the trained parameters followed by the recorded metrics
    dict are saved to ``_get_fname(env_name, seed)``.

    Args:
        env_name: One of 'inverted_pendulum', 'humanoid', 'humanoidstandup'
            or 'pusher'; any other name raises KeyError.
        seed: Random seed forwarded to ``ppo.train``.
        backend: Brax physics backend used to build the environment.
    """
    env = envs.get_environment(env_name=env_name, backend=backend)

    # Settings identical across every environment's configuration.
    shared = dict(episode_length=1000, normalize_observations=True,
                  action_repeat=1, num_envs=2048, seed=seed)
    per_env = {
        'inverted_pendulum': dict(
            num_timesteps=2_000_000, num_evals=20, reward_scaling=10,
            unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
            discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2,
            batch_size=1024),
        'humanoid': dict(
            num_timesteps=50_000_000, num_evals=10, reward_scaling=0.1,
            unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
            discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3,
            batch_size=1024),
        'humanoidstandup': dict(
            num_timesteps=100_000_000, num_evals=20, reward_scaling=0.1,
            unroll_length=15, num_minibatches=32, num_updates_per_batch=8,
            discounting=0.97, learning_rate=6e-4, entropy_cost=1e-2,
            batch_size=1024),
        'pusher': dict(
            num_timesteps=50_000_000, num_evals=20, reward_scaling=5,
            unroll_length=30, num_minibatches=16, num_updates_per_batch=8,
            discounting=0.95, learning_rate=3e-4, entropy_cost=1e-2,
            batch_size=512),
    }
    trainer = functools.partial(ppo.train, **shared, **per_env[env_name])

    # The first timestamp marks the start, so times[1] - times[0] measures
    # compilation up to the first evaluation.
    history = {"rewards": [], "stdev": [], "steps": [], "times": [datetime.now()]}

    def progress(num_steps, metrics):
        now = datetime.now()
        reward = metrics["eval/episode_reward"]
        stdev = metrics["eval/episode_reward_std"]
        history["times"].append(now)
        history["steps"].append(num_steps)
        history["rewards"].append(reward)
        history["stdev"].append(stdev)
        print("step: {} \t reward: {:.2f} \t stdev: {:.2f} \t time: {}".format(
            num_steps, reward, stdev, now))

    _, params, _ = trainer(environment=env, progress_fn=progress)

    start, first_eval, last_eval = (
        history["times"][0], history["times"][1], history["times"][-1])
    print(f'time to jit: {first_eval - start}')
    print(f'time to train: {last_eval - first_eval}')

    # The metrics dict is appended last so plot code can find it at the end.
    model.save_params(_get_fname(env_name, seed), (*params, history))

def plot_rewards(env_name, seeds, plot_fname=None):
    """Plot the evaluation-reward curve of every seed for one environment.

    Loads the results saved by ``train_model`` for each seed, draws one line
    per seed on a shared axis, and writes the figure to ``<plot_fname>.pdf``.

    Args:
        env_name: Environment whose saved results should be plotted.
        seeds: Non-empty iterable of seeds that were trained for ``env_name``.
        plot_fname: Output path without the ``.pdf`` extension. Defaults to
            the results path of the last seed in ``seeds``, which matches the
            file name this script has always produced.

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    seeds = list(seeds)
    if not seeds:
        raise ValueError(f"No seeds given to plot for {env_name!r}")

    # Read in results. train_model saves (*params, results), so the metrics
    # dict is the last element however many parameter groups were saved.
    # NOTE(review): load_params presumably unpickles the file; only load
    # results files written by this script.
    xs, ys = [], []
    for seed in seeds:
        results = model.load_params(_get_fname(env_name, seed))[-1]
        xs.append(np.array(results["steps"]))
        ys.append(np.array(results["rewards"]))
    # Transpose to (num_evals, num_seeds) so matplotlib draws one line per seed.
    x = np.vstack(xs).T
    y = np.vstack(ys).T

    # Plot formatting. Environments without a known upper bound autoscale the
    # top of the axis instead of raising KeyError.
    max_y = {'inverted_pendulum': 1100, 'humanoid': 13000, 'humanoidstandup': 75_000, 'pusher': 0}.get(env_name)
    min_y = {'reacher': -100, 'pusher': -150}.get(env_name, 0)

    if plot_fname is None:
        plot_fname = _get_fname(env_name, seeds[-1])

    # Make the plot and save it
    fig, ax = plt.subplots()
    ax.plot(x, y)
    ax.set_ylim(min_y, max_y)  # a top of None leaves the upper limit unchanged
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Reward")
    ax.set_title(f"{env_name}{suffix} ({len(seeds)} random seeds)")
    fig.tight_layout()
    fig.savefig(f"{plot_fname}.pdf")
    plt.close(fig)

# Train each configured environment once per seed, then plot all of that
# environment's seeds together on one figure.
for name in env_names:
    for run_seed in seeds:
        train_model(name, run_seed)
    plot_rewards(name, seeds)
btaba commented 1 month ago

Hey @nic-barbara thanks for checking on those two envs, scaled plots LGTM!