google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.

Saving intermediate policies during training #550

Closed: mazzamani closed this issue 2 weeks ago

mazzamani commented 4 weeks ago

My goal is to save intermediate policies. The policy at the end of training works fine for me:

        # ... main training code
        self.make_inference_fn, self.params, _ = self.train_fn(environment=env, progress_fn=self.progress_callback,
                                                               policy_params_fn=self.policy_params_callback)
        self.visualize_trajectory()

However, when I call self.visualize_trajectory() from self.policy_params_callback:

        def policy_params_callback(self, step, make_policy, params):
            self.make_inference_fn = make_policy
            self.params = params  
            self.visualize_trajectory()

I run into the following error:

  File "/brax/training_code.py", line 29, in policy_params_callback
    self.visualize_trajectory()
  File "/brax/training_code.py", line 74, in visualize_trajectory
    act, _ = jit_inference_fn(state.obs, act_rng)
  File "/miniconda3/envs/brax/lib/python3.10/site-packages/brax/training/agents/ppo/networks.py", line 44, in policy
    logits = policy_network.apply(*params, observations)
  File "/miniconda3/envs/brax/lib/python3.10/site-packages/brax/training/networks.py", line 104, in apply
    return policy_module.apply(policy_params, obs)
TypeError: argument of type 'PPONetworkParams' is not iterable

This is my visualization method, which is adapted from the example training code:

    def visualize_trajectory(self):
        inference_fn = self.make_inference_fn(self.params)
        env = self.load_environment()
        jit_env_reset = jax.jit(env.reset)
        jit_env_step = jax.jit(env.step)
        jit_inference_fn = jax.jit(inference_fn)

        trajectory = []
        rng = jax.random.PRNGKey(seed=1)
        state = jit_env_reset(rng=rng)

        for _ in tqdm(range(1000)):
            trajectory.append(state.pipeline_state)
            act_rng, rng = jax.random.split(rng)
            act, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_env_step(state, act)

        rendered_html = html.render(env.sys.tree_replace({'opt.timestep': env.dt}), trajectory)
        with open("trajectory_visualization.html", "w") as file:
            file.write(rendered_html)

        print("trajectory visualization prepared.")

Any idea why it is not working?

Edit: here is the whole script:

import functools

import jax
from brax import envs
from brax.io import html
from brax.training.agents.ppo import train as ppo

from tqdm import tqdm

class RLTrainer:
    def __init__(self):
        self.env_name = 'ant'
        self.backend = 'positional'
        self.params = None
        self.make_inference_fn = None
        self.train_fn = None

    def load_environment(self):
        env = envs.get_environment(env_name=self.env_name, backend=self.backend)
        return env

    def policy_params_callback(self, step, make_policy, params):
        self.make_inference_fn = make_policy
        self.params = params
        self.visualize_trajectory()

    def progress_callback(self, num_steps, metrics):
        print(f"Training progress: steps={num_steps}, reward={metrics['eval/episode_reward']:.2f}")

    def start_training(self, env_name, backend):
        """Begins the training in a separate thread."""
        print("Training started ...")
        self.env_name = env_name
        self.backend = backend
        self.stop_training = False
        env = self.load_environment()
        self.train_fn = {
            'ant': functools.partial(ppo.train, num_timesteps=50_000, num_evals=2, reward_scaling=10,
                                     episode_length=1000, normalize_observations=True, action_repeat=1,
                                     unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
                                     discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096,
                                     batch_size=2048, seed=1),
            # Add other environment setups as needed...
        }[self.env_name]

        self.make_inference_fn, self.params, _ = self.train_fn(environment=env, progress_fn=self.progress_callback,
                                                               policy_params_fn=self.policy_params_callback)

        print("training finished. Visualizing the policy...")
        self.visualize_trajectory()

    def visualize_trajectory(self):
        inference_fn = self.make_inference_fn(self.params)
        env = self.load_environment()
        jit_env_reset = jax.jit(env.reset)
        jit_env_step = jax.jit(env.step)
        jit_inference_fn = jax.jit(inference_fn)

        trajectory = []
        rng = jax.random.PRNGKey(seed=1)
        state = jit_env_reset(rng=rng)

        for _ in tqdm(range(1000)):
            trajectory.append(state.pipeline_state)
            act_rng, rng = jax.random.split(rng)
            act, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_env_step(state, act)

        rendered_html = html.render(env.sys.tree_replace({'opt.timestep': env.dt}), trajectory)
        with open("trajectory_visualization.html", "w") as file:
            file.write(rendered_html)

        print("trajectory visualization prepared.")

# Run the visualizer
if __name__ == "__main__":
    rl_trainer = RLTrainer()
    rl_trainer.start_training(env_name='ant', backend='positional')
btaba commented 2 weeks ago

Hi, policy_module.apply expects the params to be an iterable, not PPONetworkParams. You can unpack params via (params.policy, params.value).
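
A quick way to confirm what the callback actually receives is to inspect the structure before building the inference function. A minimal debugging sketch, assuming the tuple layout implied by the traceback and by the fix below (normalizer params first, a PPONetworkParams container second):

    def policy_params_callback(self, step, make_policy, params):
        # Assumed layout: params == (normalizer_params, PPONetworkParams(policy=..., value=...))
        print(type(params))      # expected: a tuple
        print(type(params[0]))   # running-statistics normalizer state (assumption)
        print(type(params[1]))   # PPONetworkParams holding .policy and .value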

mazzamani commented 1 week ago

Thanks for the answer. So, this worked for me:

    inference_fn = self.make_inference_fn((self.params[0], self.params[1].policy))
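
For completeness, one way to fold that working call back into the callback from the original script. This is a sketch that assumes params is (normalizer_params, PPONetworkParams), as the line above implies:

    def policy_params_callback(self, step, make_policy, params):
        self.make_inference_fn = make_policy
        # Keep only the policy head: make_policy expects (normalizer_params, policy_params),
        # not the full PPONetworkParams container (assumption based on the fix above).
        self.params = (params[0], params[1].policy)
        self.visualize_trajectory()

With the unpacking done in the callback, visualize_trajectory() should be able to keep calling self.make_inference_fn(self.params) unchanged, matching the end-of-training path that already works.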