rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

Question about best way to use set_task() in HalfCheetahVel #2280

Closed · mpiseno closed this issue 3 years ago

mpiseno commented 3 years ago

I am brand new to garage and my understanding of Samplers and Workers is not fully there yet, so any additional context about what is happening behind the scenes when answering the following question would be much appreciated.

My question: I am training an agent with SAC on HalfCheetahVel, and I want to call set_task() to change the goal velocity at the beginning of every epoch/episode. However, it seems like I have to modify the provided SAC implementation's train() function to do so. My current workaround is to construct a brand new environment whenever a new epoch has started and pass it as the env_update parameter to trainer.obtain_samples (see below). Is there a cleaner way to accomplish the same thing? I looked into SetTaskSampler, but it doesn't seem to be what I want, because I don't want to sample a bunch of tasks; I just want to set a specific new task once per epoch.

import numpy as np

from dowel import tabular
from garage import StepType
from garage.envs.mujoco import HalfCheetahVelEnv
from garage.torch.algos import SAC


class Modified_SAC(SAC):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self, trainer):
        """Obtain samplers and start actual training for each epoch.
        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.
        Returns:
            float: The average return in last epoch cycle.
        """
        if not self._eval_env:
            self._eval_env = trainer.get_env_copy()
        last_return = None
        last_epoch = None
        for epoch in trainer.step_epochs():
            for _ in range(self._steps_per_epoch):
                if not (self.replay_buffer.n_transitions_stored >=
                        self._min_buffer_size):
                    batch_size = int(self._min_buffer_size)
                else:
                    batch_size = None

                # When a new epoch starts, construct a fresh environment with
                # the new goal velocity and hand it to the sampler via
                # env_update. new_goal_vel() is my own helper for picking the
                # velocity for this epoch.
                new_env = None
                if epoch != last_epoch:
                    new_env = HalfCheetahVelEnv()
                    new_env.set_task(new_goal_vel(epoch))
                    last_epoch = epoch

                trainer.step_episode = trainer.obtain_samples(
                    trainer.step_itr, batch_size, env_update=new_env
                )
                path_returns = []
                for path in trainer.step_episode:
                    self.replay_buffer.add_path(
                        dict(observation=path['observations'],
                             action=path['actions'],
                             reward=path['rewards'].reshape(-1, 1),
                             next_observation=path['next_observations'],
                             terminal=np.array([
                                 step_type == StepType.TERMINAL
                                 for step_type in path['step_types']
                             ]).reshape(-1, 1)))
                    path_returns.append(sum(path['rewards']))
                assert len(path_returns) == len(trainer.step_episode)
                self.episode_rewards.append(np.mean(path_returns))
                for _ in range(self._gradient_steps):
                    policy_loss, qf1_loss, qf2_loss = self.train_once()
            last_return = self._evaluate_policy(trainer.step_itr)
            self._log_statistics(policy_loss, qf1_loss, qf2_loss)
            tabular.record('TotalEnvSteps', trainer.total_env_steps)
            trainer.step_itr += 1

        return np.mean(last_return)
krzentner commented 3 years ago

The goal velocity of HalfCheetahVelEnv isn't included in the observation, so changing the target velocity every epoch makes the environment highly non-Markovian. This tends to make SAC perform incredibly badly, since it is forced to marginalize out the target velocity. Assuming you get velocities from HalfCheetahVelEnv.sample_tasks, your mean target velocity will be zero, so the Q-function probably won't fit at all.

If you just want to train SAC on a HalfCheetahVelEnv with a randomly varying target velocity, I would recommend writing an environment wrapper that samples a new task in reset() and extends the observation to include the target velocity. That resamples per episode rather than per epoch, but given the batch sizes typically used in off-policy learning, it should be the same in expectation. (Alternatively, you can make reset() switch tasks less frequently, but that shouldn't have an effect.)
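For concreteness, here is a minimal sketch of such a wrapper. It is not part of garage; it assumes the old-style gym API where step() returns (obs, reward, done, info), and that HalfCheetahVelEnv tasks are dicts with a 'velocity' key, so adjust it to your environment's actual sample_tasks()/set_task() interface:

import gym
import numpy as np


class RandomVelObsWrapper(gym.Wrapper):
    """Resample a target velocity on reset() and expose it in the observation."""

    def __init__(self, env):
        super().__init__(env)
        self._goal_vel = 0.0
        # One extra observation dimension for the current target velocity.
        low = np.append(self.env.observation_space.low, -np.inf)
        high = np.append(self.env.observation_space.high, np.inf)
        self.observation_space = gym.spaces.Box(low=low, high=high,
                                                dtype=np.float64)

    def _augment(self, obs):
        return np.append(obs, self._goal_vel)

    def reset(self, **kwargs):
        # Pick a fresh task every episode and tell the wrapped env about it.
        task = self.env.sample_tasks(1)[0]
        self.env.set_task(task)
        self._goal_vel = task['velocity']  # assumed task dict layout
        return self._augment(self.env.reset(**kwargs))

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._augment(obs), reward, done, info

The wrapped environment can then be handed to SAC like any other gym environment (e.g. via GymEnv and normalize), with no changes to train().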

If you're trying to write a multi-task algorithm based on SAC, you will probably need to modify SAC in a way similar to above. Note that passing garage.sampler.SetTaskUpdate to obtain_samples instead of constructing a new environment will probably be a little more efficient.
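As a rough sketch only (the exact SetTaskUpdate constructor arguments may differ between garage versions, so check garage/sampler/env_update.py in your copy before relying on this), the per-epoch block in the modified train() above could become:

from garage.sampler import SetTaskUpdate

# Sketch: ask the workers to call set_task() on their existing environments
# instead of constructing a brand new HalfCheetahVelEnv each epoch.
# new_goal_vel() is the helper from the snippet above; the third argument
# (an optional wrapper constructor) is an assumption and may not exist in
# every garage version.
env_update = None
if epoch != last_epoch:
    env_update = SetTaskUpdate(HalfCheetahVelEnv, new_goal_vel(epoch), None)
    last_epoch = epoch

trainer.step_episode = trainer.obtain_samples(
    trainer.step_itr, batch_size, env_update=env_update)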

Of course, most people just train SAC on HalfCheetah-v2 from OpenAI Gym, which HalfCheetahVelEnv is based on.