nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Reward hyperparameters vs original paper #21

Closed · ShaneFlandermeyer closed 3 months ago

ShaneFlandermeyer commented 3 months ago

Hello,

In the config file in this repository, the loss coefficients are consistency/reward/value = 20/0.1/0.1, but in the paper you use 10/0.5/1.0. Is there a particular reason for this change? Have you found that the new configuration works better across a wider range of tasks?

My intuition is that giving a higher weight to the consistency loss should get the agent to a "good" latent representation more quickly, but I would greatly appreciate your thought process.

nicklashansen commented 3 months ago

Good catch! This is a typo in the paper. The correct values to use (which we also used in our experiments) are the default values in the codebase, i.e., 20/0.1/0.1. I would not expect the results to differ that much between the two configurations though. In general, I believe that whether a larger consistency loss coef is beneficial or not is somewhat task dependent, but this particular choice of coefs worked well enough in all domains that we tried. I will update the paper with the corrected values. Thank you!
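
For anyone cross-referencing with the paper, here is a minimal sketch of where these coefficients enter the objective (illustrative standalone code; the variable names mirror the config keys, and the per-term losses are placeholders):

    # Codebase defaults (config.yaml: consistency_coef / reward_coef / value_coef).
    consistency_coef, reward_coef, value_coef = 20.0, 0.1, 0.1

    def total_loss(consistency_loss, reward_loss, value_loss):
        # Each per-term loss is assumed to already be averaged over the rollout horizon.
        return (consistency_coef * consistency_loss
                + reward_coef * reward_loss
                + value_coef * value_loss)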

ShaneFlandermeyer commented 3 months ago

EDIT: The more I edit this comment, the more I believe it's actually worth its own issue. Lmk if I should make a new one!

Awesome! I have one more unrelated question that isn't worth its own issue:

Have you tested the agent on Humanoid-v4 or similar gym/mujoco environments? My JAX implementation and your implementation both do well on HalfCheetah-v4, but neither exceeds a reward of 300 on Humanoid. I'm trying to determine whether it's an agent hyperparameter problem or something about the environment itself; Humanoid has some quirks (early termination, in particular) that I think could make an impact.

For reference, the CleanRL SAC agent gets rewards >4000 in ~500k timesteps. The only change I'm making in this repo is in make_env as below.

def make_env(cfg):
    """
    Make an environment for TD-MPC2 experiments.
    """
    gym.logger.set_level(40)
    if cfg.multitask:
        env = make_multitask_env(cfg)

    else:
        # env = None
        # for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
        #   try:
        #       env = fn(cfg)
        #   except ValueError:
        #       pass
        # if env is None:
        #   raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
        # Hardcode a Gymnasium MuJoCo task in place of the task-based factories above
        env = gym.make('Humanoid-v4')
        env = TensorWrapper(env)
    if cfg.get('obs', 'state') == 'rgb':
        env = PixelWrapper(cfg, env)
    try: # Dict
        cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
    except: # Box
        cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
    cfg.action_dim = env.action_space.shape[0]
    # cfg.episode_length = env.max_episode_steps
    cfg.episode_length = 500  # hardcoded episode length for this experiment
    cfg.seed_steps = max(1000, 5*cfg.episode_length)
    return env

ShaneFlandermeyer commented 3 months ago

A minor update: I have narrowed the problem above down to the plan step. It works when actions are just sampled directly.

ShaneFlandermeyer commented 3 months ago

> What do you mean by "actions are just sampled directly"? Just set mpc=false in config.yaml?

Right. Just sampling directly from the policy distribution.

To make it work without planning, I also added termination handling to the TD target computation, setting the discount to zero in terminal states. I'm currently experimenting with a learned continuation flag that is applied in _estimate_value during planning. Something like this:

    G, discount = 0, 1
    for t in range(self.horizon):
      reward = self.model.reward(z, actions[t])
      continues = self.model.continuation(z)  # predicted continuation probability in [0, 1]
      z = self.model.next(z, actions[t])
      G += discount * reward
      discount *= self.discount * continues   # softly terminate the rollout
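
For the TD-target side, the change amounts to the standard termination mask. A minimal sketch with illustrative names (not the exact code in either repo):

    # Sketch only: zero out the bootstrap term whenever the transition
    # terminated the episode. `q_next` stands in for the (target) Q-value
    # of the next latent state under the policy.
    def td_target(reward, q_next, done, discount=0.99):
        return reward + discount * (1.0 - done) * q_next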

I suspect that in tasks with long-range temporal correlations, planning could have side effects: the Q-value estimate is based on an increasingly inaccurate final latent state, while the intermediate planned states contribute little to the value estimate.

I think some recurrence in the world model would help a lot with this. Dreamer's world model does well over some pretty long time horizons, for instance.

ShaneFlandermeyer commented 3 months ago

I played around with this a bit more yesterday, and I'm getting great results on Humanoid now. The following changes helped: the first part is the planning-time value estimate, and the rest is the loss computation in the update step. The code is in JAX, but hopefully it's still readable enough:

    # Planning: value estimate of a latent rollout with a soft termination signal
    G, discount = 0, 1
    for t in range(self.horizon):
      reward, _ = self.model.reward(
          z, actions[t], self.model.reward_model.params)
      z = self.model.next(z, actions[t], self.model.dynamics_model.params)
      G += discount * reward

      if self.model.predict_continues:
        continues = jax.nn.sigmoid(self.model.continue_model.apply_fn(
            {'params': self.model.continue_model.params}, z)).squeeze(-1)
      else:
        continues = 1

      discount *= self.discount * continues

      # --- update step: TD targets, latent rollout, and per-term losses ---
      next_z = sg(self.model.encode(next_observations, encoder_params))
      td_targets = self.td_target(next_z, rewards, dones, key=target_dropout)

      # Latent rollout (compute latent dynamics + consistency loss)
      zs = jnp.empty((self.horizon+1, self.batch_size, next_z.shape[-1]))
      z = self.model.encode(observations[0], encoder_params)
      zs = zs.at[0].set(z)
      consistency_loss = jnp.zeros(self.batch_size)
      discount = jnp.ones(self.batch_size)
      horizon = jnp.zeros(self.batch_size)
      for t in range(self.horizon):
        z = self.model.next(z, actions[t], dynamics_params)
        consistency_loss += jnp.mean(
            (z - next_z[t])**2 * discount[:, None], -1)
        zs = zs.at[t+1].set(z)

        horizon += (discount > 0)
        discount *= self.rho * (1 - dones[t])

      # Get logits for loss computations
      _, q_logits = self.model.Q(
          zs[:-1], actions, value_params, value_dropout_key1)
      _, reward_logits = self.model.reward(zs[:-1], actions, reward_params)
      if self.model.predict_continues:
        continue_logits = self.model.continue_model.apply_fn(
            {'params': continue_params}, zs[1:]).squeeze(-1)

      reward_loss = jnp.zeros(self.batch_size)
      value_loss = jnp.zeros(self.batch_size)
      continue_loss = jnp.zeros(self.batch_size)
      discount = jnp.ones(self.batch_size)
      for t in range(self.horizon):
        reward_loss += soft_crossentropy(reward_logits[t], rewards[t],
                                         self.model.symlog_min,
                                         self.model.symlog_max,
                                         self.model.num_bins) * discount

        if self.model.predict_continues:
          continue_loss += optax.sigmoid_binary_cross_entropy(
              continue_logits[t], 1 - dones[t]) * discount

        for q in range(self.model.num_value_nets):
          value_loss += soft_crossentropy(q_logits[q, t], td_targets[t],
                                          self.model.symlog_min,
                                          self.model.symlog_max,
                                          self.model.num_bins) * discount

        discount *= self.rho * (1 - dones[t])

      consistency_loss = (consistency_loss / horizon).mean()
      reward_loss = (reward_loss / horizon).mean()
      value_loss = (value_loss / (horizon * self.model.num_value_nets)).mean()  # average over horizon and Q-ensemble
      continue_loss = (continue_loss / horizon).mean()
      total_loss = (
          self.consistency_coef * consistency_loss +
          self.reward_coef * reward_loss +
          self.value_coef * value_loss +
          self.continue_coef * continue_loss
      )

I still need to do some more testing to make sure I didn't accidentally break anything in the infinite-horizon case, but these results are extremely promising! I'll hopefully push them to the JAX repo some time this weekend.
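
(As a quick standalone sanity check for the infinite-horizon case, assuming the same rho-based weighting as in the snippets above: with no terminations, the termination-aware weight should reduce to the original rho^t schedule.)

    # Illustrative, standalone check: with dones == 0 the per-step weight
    # rho^t * prod_k (1 - done_k) collapses to plain rho^t.
    rho, horizon = 0.5, 4
    dones = [0, 0, 0, 0]
    discount, weights = 1.0, []
    for t in range(horizon):
        weights.append(discount)
        discount *= rho * (1 - dones[t])
    assert weights == [rho**t for t in range(horizon)]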

nicklashansen commented 3 months ago

@ShaneFlandermeyer this looks really interesting! It's similar to how I handle early termination in the episodic-rl branch, except that I use the binarized signal during planning. I will definitely give the soft version you're proposing a try, thanks for the suggestion!

Snippets from the relevant parts:

https://github.com/nicklashansen/tdmpc2/blob/ff02f41e73cce8b6ef9eef99c3830e131fbaf97f/tdmpc2/tdmpc2.py#L95-L107

and

https://github.com/nicklashansen/tdmpc2/blob/ff02f41e73cce8b6ef9eef99c3830e131fbaf97f/tdmpc2/tdmpc2.py#L265
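
For comparison, a rough standalone sketch of the two options for using the predicted termination signal inside the planning rollout (illustrative variable names, not the exact code behind the links above):

    # `continues` is a predicted probability that the episode keeps going.
    step_discount = 0.99   # per-step discount factor
    continues = 0.9        # predicted continuation probability
    discount = 1.0

    # Binarized signal (one way to do it; the episodic-rl branch binarizes during planning):
    discount_hard = discount * step_discount * float(continues > 0.5)

    # Soft signal (as proposed above): weight by the probability itself.
    discount_soft = discount * step_discount * continues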

ShaneFlandermeyer commented 3 months ago

Ha, nice. I didn't know you had already looked into this! I just verified and pushed all my changes. Hopefully they can be useful. Gonna close this issue now.