ShaneFlandermeyer closed this issue 3 months ago.
Good catch! This is a typo in the paper. The correct values (which we also used in our experiments) are the defaults in the codebase, i.e., 20/0.1/0.1. I would not expect the results to differ much between the two configurations, though. In general, I believe that whether a larger consistency loss coefficient is beneficial is somewhat task-dependent, but this particular choice of coefficients worked well enough in all domains we tried. I will update the paper with the corrected values. Thank you!
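To see how different the two settings are in relative terms, here is a small plain-Python sketch that expresses each coefficient as a multiple of the reward coefficient. The names CONFIGS, relative_weights and weighted_total are illustrative only and do not come from the codebase. Under 20/0.1/0.1 the consistency term carries 200x the reward weight, versus 20x under 10/0.5/1.0. This ignores the fact that the raw magnitudes of the individual losses also differ, so it is only a rough guide to how much each term dominates.

```python
# Loss coefficients discussed in this thread (consistency / reward / value):
# codebase defaults vs. the values printed in the paper.
CONFIGS = {
    "codebase": {"consistency": 20.0, "reward": 0.1, "value": 0.1},
    "paper": {"consistency": 10.0, "reward": 0.5, "value": 1.0},
}


def relative_weights(coefs):
    """Express every coefficient as a multiple of the reward coefficient."""
    base = coefs["reward"]
    return {name: weight / base for name, weight in coefs.items()}


def weighted_total(coefs, losses):
    """Weighted sum of the individual loss terms (hypothetical helper)."""
    return sum(coefs[name] * losses[name] for name in coefs)


for label, coefs in CONFIGS.items():
    print(label, relative_weights(coefs))
```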
EDIT: The more I edit this comment, the more I believe it's actually worth its own issue. Lmk if I should make a new one!
Awesome! I have one more unrelated question that isn't worth its own issue:
Have you tested the agent on Humanoid-v4 or similar Gym/MuJoCo environments? My JAX implementation and your implementation both do well on HalfCheetah-v4, but neither exceeds an episode return of 300 on Humanoid. I'm trying to determine whether it's an agent hyperparameter problem or something about the environment. Some quirks of the environment that I think could make a difference:
For reference, the CleanRL SAC agent reaches returns above 4000 in about 500k timesteps. The only change I'm making in this repo is in make_env, shown below.
def make_env(cfg):
    """
    Make an environment for TD-MPC2 experiments.
    """
    gym.logger.set_level(40)
    if cfg.multitask:
        env = make_multitask_env(cfg)
    else:
        # env = None
        # for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
        #     try:
        #         env = fn(cfg)
        #     except ValueError:
        #         pass
        # if env is None:
        #     raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
        env = gym.make('Humanoid-v4')
        env = TensorWrapper(env)
    if cfg.get('obs', 'state') == 'rgb':
        env = PixelWrapper(cfg, env)
    try:  # Dict
        cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
    except:  # Box
        cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
    cfg.action_dim = env.action_space.shape[0]
    # cfg.episode_length = env.max_episode_steps
    cfg.episode_length = 500
    cfg.seed_steps = max(1000, 5*cfg.episode_length)
    return env
A minor update: I have narrowed the problem above down to the plan step. It works when actions are just sampled directly.
What do you mean by "actions are just sampled directly"? Do you just set mpc=false in config.yaml?
Right. Just sampling directly from the policy distribution.
To make it work without planning, I also added termination to the TD target computation, setting the discount to zero in terminal states. I'm currently experimenting with a learned continuation flag that is applied in _estimate_value during planning, something like this:
for t in range(self.horizon):
    reward = self.model.reward(z, actions[t])
    continues = self.model.continuation(z)
    z = self.model.next(z, actions[t])
    G += discount * reward
    discount *= self.discount * continues
I suspect that in tasks with long-range dependencies, planning can have side effects: the Q-value is estimated from an inaccurate final latent state, while the intermediate planned states contribute little useful signal to the value estimate.
I think some recurrence in the world model would help a lot with this. Dreamer's world model does well over some pretty long time horizons, for instance.
I played around with this a bit more yesterday, and I'm getting great results on Humanoid now. The following changes helped. The code snippets are in JAX, but hopefully they're still readable:
Zeroing the bootstrap term of the TD target in terminal states:
reward + (1 - done) * self.discount * Q
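A minimal NumPy sketch of that TD target, assuming 1-D batches and a hypothetical td_target helper (not from either repo). With the Gymnasium step API, terminated and truncated are reported separately; only terminated should zero the bootstrap, since hitting a time limit does not mean the state has zero future value. By default Humanoid-v4 terminates early when the torso height leaves its healthy range, which is one reason this matters there.

```python
import numpy as np


def td_target(rewards, next_q, terminated, discount=0.99):
    """One-step TD target that stops bootstrapping at true terminal states.

    rewards, next_q, terminated: arrays of shape (batch,).
    terminated should be 1.0 only when the environment actually ended the
    episode (e.g. the humanoid fell), not when a time limit was reached.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    next_q = np.asarray(next_q, dtype=np.float64)
    not_done = 1.0 - np.asarray(terminated, dtype=np.float64)
    return rewards + not_done * discount * next_q


targets = td_target(rewards=[1.0, 1.0], next_q=[10.0, 10.0], terminated=[0.0, 1.0])
print(targets)  # first element bootstraps from next_q, second does not
```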
Applying a learned continuation probability to the discount in _estimate_value:
G, discount = 0, 1
for t in range(self.horizon):
    reward, _ = self.model.reward(
        z, actions[t], self.model.reward_model.params)
    z = self.model.next(z, actions[t], self.model.dynamics_model.params)
    G += discount * reward
    if self.model.predict_continues:
        continues = jax.nn.sigmoid(self.model.continue_model.apply_fn(
            {'params': self.model.continue_model.params}, z)).squeeze(-1)
    else:
        continues = 1
    discount *= self.discount * continues
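A self-contained NumPy sketch of the same idea, with toy stand-ins for the learned models. The callables next_fn, reward_fn, continue_logit_fn and terminal_value_fn are hypothetical placeholders, and the final bootstrap term is an assumption, since the snippet above stops before it.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def estimate_value(z, actions, next_fn, reward_fn, continue_logit_fn,
                   terminal_value_fn, gamma=0.99):
    """Return of an imagined rollout, discounted by a learned continuation probability.

    z: (batch, latent_dim) start latents; actions: (horizon, batch, action_dim).
    After each step the running discount is multiplied by gamma * p(continue | z'),
    so rewards and the bootstrap after a likely termination are down-weighted.
    """
    G = np.zeros(z.shape[0])
    discount = np.ones(z.shape[0])
    for a in actions:
        reward = reward_fn(z, a)
        z = next_fn(z, a)
        G += discount * reward
        discount = discount * gamma * sigmoid(continue_logit_fn(z))
    # Assumed bootstrap from a terminal value estimate (e.g. Q(z, pi(z))).
    return G + discount * terminal_value_fn(z)


# Toy models: frozen latent, reward of 1 per step, zero terminal value.
z0 = np.zeros((2, 4))
acts = np.zeros((3, 2, 1))
common = dict(next_fn=lambda z, a: z,
              reward_fn=lambda z, a: np.ones(z.shape[0]),
              terminal_value_fn=lambda z: np.zeros(z.shape[0]))
keep_going = estimate_value(z0, acts, continue_logit_fn=lambda z: np.full(z.shape[0], 50.0), **common)
stop_now = estimate_value(z0, acts, continue_logit_fn=lambda z: np.full(z.shape[0], -50.0), **common)
print(keep_going, stop_now)
```

With a confident "continue" prediction the rollout behaves like the usual discounted sum; with a confident "terminate" prediction only the first reward survives.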
Masking out the losses for steps after a termination, using the discount variable:
next_z = sg(self.model.encode(next_observations, encoder_params))
td_targets = self.td_target(next_z, rewards, dones, key=target_dropout)
# Latent rollout (compute latent dynamics + consistency loss)
zs = jnp.empty((self.horizon+1, self.batch_size, next_z.shape[-1]))
z = self.model.encode(observations[0], encoder_params)
zs = zs.at[0].set(z)
consistency_loss = jnp.zeros(self.batch_size)
discount = jnp.ones(self.batch_size)
horizon = jnp.zeros(self.batch_size)
for t in range(self.horizon):
    z = self.model.next(z, actions[t], dynamics_params)
    consistency_loss += jnp.mean(
        (z - next_z[t])**2 * discount[:, None], -1)
    zs = zs.at[t+1].set(z)
    horizon += (discount > 0)
    discount *= self.rho * (1 - dones[t])
# Get logits for loss computations
_, q_logits = self.model.Q(
zs[:-1], actions, value_params, value_dropout_key1)
_, reward_logits = self.model.reward(zs[:-1], actions, reward_params)
if self.model.predict_continues:
    continue_logits = self.model.continue_model.apply_fn(
        {'params': continue_params}, zs[1:]).squeeze(-1)
reward_loss = jnp.zeros(self.batch_size)
value_loss = jnp.zeros(self.batch_size)
continue_loss = jnp.zeros(self.batch_size)
discount = jnp.ones(self.batch_size)
for t in range(self.horizon):
    reward_loss += soft_crossentropy(reward_logits[t], rewards[t],
                                     self.model.symlog_min,
                                     self.model.symlog_max,
                                     self.model.num_bins) * discount
    if self.model.predict_continues:
        continue_loss += optax.sigmoid_binary_cross_entropy(
            continue_logits[t], 1 - dones[t]) * discount
    for q in range(self.model.num_value_nets):
        value_loss += soft_crossentropy(q_logits[q, t], td_targets[t],
                                        self.model.symlog_min,
                                        self.model.symlog_max,
                                        self.model.num_bins) * discount
    discount *= self.rho * (1 - dones[t])
consistency_loss = (consistency_loss / horizon).mean()
reward_loss = (reward_loss / horizon).mean()
value_loss = (value_loss / (horizon * self.model.num_value_nets)).mean()  # average over valid steps and Q-ensemble members
continue_loss = (continue_loss / horizon).mean()
total_loss = (
self.consistency_coef * consistency_loss +
self.reward_coef * reward_loss +
self.value_coef * value_loss +
self.continue_coef * continue_loss
)
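The masking and normalization in the loss code above can be isolated into a small NumPy sketch. The helper names are hypothetical. The sigmoid cross-entropy uses the standard numerically stable form, which is mathematically equivalent to optax.sigmoid_binary_cross_entropy. Each sample is divided by its own count of valid steps, and when the per-step losses are summed over an ensemble of num_terms value heads, the result is additionally averaged over those heads.

```python
import numpy as np


def sigmoid_bce(logits, targets):
    """Numerically stable binary cross-entropy computed from logits."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    return np.maximum(logits, 0.0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))


def masked_horizon_mean(per_step_losses, dones, rho=0.5, num_terms=1):
    """Mean of rho-weighted per-step losses, ignoring steps after a termination.

    per_step_losses: (horizon, batch), already summed over num_terms ensemble heads.
    dones: (horizon, batch), 1.0 where the transition at step t ended the episode.
    Step t keeps weight rho**t until a done is seen; the terminal step itself
    still counts, and every later step gets weight 0.
    """
    horizon, batch = per_step_losses.shape
    total = np.zeros(batch)
    valid_steps = np.zeros(batch)
    discount = np.ones(batch)
    for t in range(horizon):
        total += per_step_losses[t] * discount
        valid_steps += discount > 0
        discount = discount * rho * (1.0 - dones[t])
    return float((total / (valid_steps * num_terms)).mean())


# Two samples, horizon 3: sample 0 never terminates, sample 1 terminates at t=0.
dones = np.array([[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]])
losses = np.ones((3, 2))
print(masked_horizon_mean(losses, dones))
# Continuation targets are 1 - done, as in the continue loss above.
continue_losses = sigmoid_bce(np.zeros((3, 2)), 1.0 - dones)
print(masked_horizon_mean(continue_losses, dones))
```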
I still need to do some more testing to make sure I didn't accidentally break anything in the infinite-horizon case, but these results are extremely promising! I'll hopefully push the changes to the JAX repo sometime this weekend.
@ShaneFlandermeyer this looks really interesting! This is similar to how I deal with early termination in the episodic-rl branch, except that I use the binarized signal during planning. I will definitely give the soft version that you're proposing a try, thanks for the suggestion!
Snippets from the relevant parts:
and
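To compare the two variants mentioned here, a small NumPy sketch of the per-step factor multiplied into the planning discount. The soft version uses the predicted continuation probability directly; the binarized version thresholds it. The function name and the 0.5 threshold are assumptions for illustration, not the rule used in the episodic-rl branch.

```python
import numpy as np


def continuation_weight(continue_logits, binarize=False, threshold=0.5):
    """Per-step factor multiplied into the planning discount.

    soft:      p(continue) = sigmoid(logit), shrinking the discount gradually
    binarized: 1.0 if p(continue) > threshold else 0.0, a hard cut-off
    """
    p = 1.0 / (1.0 + np.exp(-np.asarray(continue_logits, dtype=np.float64)))
    if binarize:
        return (p > threshold).astype(np.float64)
    return p


logits = np.array([2.0, -2.0, 0.0])
print(continuation_weight(logits))                 # soft probabilities
print(continuation_weight(logits, binarize=True))  # hard 0/1 mask
```

The soft form hedges when the continuation model is uncertain (a probability near 0.5 halves the remaining discount), while the binarized form drops everything after a predicted termination, which is sharper but more sensitive to misclassified states.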
Ha, nice. I didn't know you had already looked into this! I just verified and pushed all my changes. Hopefully they can be useful. Gonna close this issue now.
Hello,
In the config file in this repository, the loss coefficients are consistency/reward/value = 20/0.1/0.1, but in the paper you use 10/0.5/1.0. Is there a particular reason for this change? Have you found that the new configuration works better across a wider range of tasks?
My intuition is that giving a higher weight to the consistency loss should get the agent to a "good" latent representation more quickly, but I would greatly appreciate your thought process.