rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

Dataset based Trainer #54

Open redknightlois opened 5 years ago

redknightlois commented 5 years ago

This example dataset-based trainer also does expert signal recollection, which is why I didn't open a PR; I'll leave it to you to decide which parts make sense for rlkit.

import abc

import gtimer as gt
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import rlkit.torch.pytorch_util as ptu
from rlkit.core.rl_algorithm import BaseRLAlgorithm
from rlkit.data_management.replay_buffer import ReplayBuffer
from rlkit.samplers.data_collector import PathCollector

# Note: printout and AtariPathCollectorWithEmbedder used below are
# project-specific helpers, not part of stock rlkit.


class OptimizedBatchRLAlgorithm(BaseRLAlgorithm, metaclass=abc.ABCMeta):
    def __init__(
            self,
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector: PathCollector,
            evaluation_data_collector: PathCollector,
            replay_buffer: ReplayBuffer,
            batch_size,
            max_path_length,
            num_epochs,
            num_eval_steps_per_epoch,
            num_expl_steps_per_train_loop,
            num_trains_per_train_loop,
            num_train_loops_per_epoch=1,
            min_num_steps_before_training=0,
            max_num_steps_before_training=1e5,
            expert_data_collector: PathCollector = None,
    ):
        super().__init__(
            trainer,
            exploration_env,
            evaluation_env,
            exploration_data_collector,
            evaluation_data_collector,
            replay_buffer,
        )

        assert isinstance(replay_buffer, Dataset), "The replay buffer must implement the PyTorch Dataset interface to use this version."

        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.num_epochs = num_epochs
        self.num_eval_steps_per_epoch = num_eval_steps_per_epoch
        self.num_trains_per_train_loop = num_trains_per_train_loop
        self.num_train_loops_per_epoch = num_train_loops_per_epoch
        self.num_expl_steps_per_train_loop = num_expl_steps_per_train_loop
        self.min_num_steps_before_training = min_num_steps_before_training
        self.max_num_steps_before_training = max_num_steps_before_training
        self.expert_data_collector = expert_data_collector

    def _train(self):
        if self.min_num_steps_before_training > 0:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )

            self.replay_buffer.add_paths(init_expl_paths)

            self.expert_data_collector.end_epoch(-1)
            self.expl_data_collector.end_epoch(-1)

        if self.expert_data_collector is not None:
            new_expl_paths = self.expert_data_collector.collect_new_paths(
                self.max_path_length,
                min(int(self.replay_buffer.max_buffer_size * 0.5), self.max_num_steps_before_training),
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(new_expl_paths)

        dataset_loader = torch.utils.data.DataLoader(
            self.replay_buffer,
            pin_memory=True,
            batch_size=self.batch_size,
            num_workers=0,
        )

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            printout('Evaluation sampling')
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            for _ in range(self.num_train_loops_per_epoch):

                printout('Exploration sampling')
                new_expl_paths = self.expl_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_expl_steps_per_train_loop,
                    discard_incomplete_paths=False,
                )
                gt.stamp('exploration sampling', unique=False)

                self.replay_buffer.add_paths(new_expl_paths)
                gt.stamp('data storing', unique=False)

                self.training_mode(True)

                i = 0
                with tqdm(total=self.num_trains_per_train_loop) as pbar:
                    while True:

                        for _, data in enumerate(dataset_loader, 0):
                            if i >= self.num_trains_per_train_loop:
                                break  # We are done

                            observations = data[0].to(ptu.device)
                            actions = data[1].to(ptu.device)
                            rewards = data[2].to(ptu.device)
                            terminals = data[3].to(ptu.device).float()
                            next_observations = data[4].to(ptu.device)
                            env_infos = data[5]

                            train_data = dict(
                                observations=observations,
                                actions=actions,
                                rewards=rewards,
                                terminals=terminals,
                                next_observations=next_observations,
                            )

                            for key in env_infos.keys():
                                train_data[key] = env_infos[key]

                            self.trainer.train(train_data)
                            pbar.update(1)
                            i += 1

                        if i >= self.num_trains_per_train_loop:
                            break

                gt.stamp('training', unique=False)
                self.training_mode(False)

                if isinstance(self.expl_data_collector, AtariPathCollectorWithEmbedder):
                    eval_policy = self.eval_data_collector.get_snapshot()['policy']
                    self.expl_data_collector.evaluate(eval_policy)

            self._end_epoch(epoch)

    def to(self, device):
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        for net in self.trainer.networks:
            net.train(mode)

nm-narasimha commented 5 years ago

Thanks for defining this class. Can you share an example of how to use this trainer class with DDPG and SAC?

redknightlois commented 5 years ago

The standard examples show how to do that; there is no difference between using the current algorithm class and this one. I use #52 for dataset-size reasons, but the rest is pretty straightforward.
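For concreteness, here is a rough, untested sketch of wiring the class above into the standard SAC example. DatasetReplayBuffer is a placeholder name for a Dataset-compatible replay buffer such as the one proposed in #52, and the hyperparameters are just the usual example defaults.

import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
from rlkit.torch.sac.sac import SACTrainer

expl_env = NormalizedBoxEnv(gym.make('HalfCheetah-v2'))
eval_env = NormalizedBoxEnv(gym.make('HalfCheetah-v2'))
obs_dim = expl_env.observation_space.low.size
action_dim = expl_env.action_space.low.size

M = 256
qf1 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
qf2 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
target_qf1 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
target_qf2 = FlattenMlp(input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M])
policy = TanhGaussianPolicy(obs_dim=obs_dim, action_dim=action_dim, hidden_sizes=[M, M])

trainer = SACTrainer(
    env=eval_env,
    policy=policy,
    qf1=qf1,
    qf2=qf2,
    target_qf1=target_qf1,
    target_qf2=target_qf2,
    discount=0.99,
    soft_target_tau=5e-3,
    policy_lr=3e-4,
    qf_lr=3e-4,
    reward_scale=1.0,
    use_automatic_entropy_tuning=True,
)

# Placeholder: any replay buffer that implements both rlkit's ReplayBuffer
# interface and torch.utils.data.Dataset (e.g. the one proposed in #52).
replay_buffer = DatasetReplayBuffer(int(1e6), expl_env)

algorithm = OptimizedBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=expl_env,
    evaluation_env=eval_env,
    exploration_data_collector=MdpPathCollector(expl_env, policy),
    evaluation_data_collector=MdpPathCollector(eval_env, MakeDeterministic(policy)),
    replay_buffer=replay_buffer,
    batch_size=256,
    max_path_length=1000,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)
algorithm.to(ptu.device)
algorithm.train()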

nm-narasimha commented 5 years ago

Thanks. @redknightlois, do you have a sample replay buffer compatible with the PyTorch Dataset class? Is EnvReplayBuffer, or any other class in rlkit.data_management, compatible?

Thanks, Narasimha

redknightlois commented 5 years ago

#52 is a PyTorch Dataset class.
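For anyone wondering what such a buffer could look like, a minimal sketch is below. This is not the actual #52 code; it assumes the internal array names of rlkit's SimpleReplayBuffer (_observations, _actions, _rewards, _terminals, _next_obs) and omits env_infos.

import numpy as np
from torch.utils.data import Dataset

from rlkit.data_management.env_replay_buffer import EnvReplayBuffer


class DatasetEnvReplayBuffer(EnvReplayBuffer, Dataset):
    """EnvReplayBuffer that can also be consumed by a torch DataLoader."""

    def __len__(self):
        # Number of transitions currently stored.
        return self.num_steps_can_sample()

    def __getitem__(self, index):
        # Return one transition in the order the trainer above unpacks it:
        # (obs, action, reward, terminal, next_obs, env_infos).
        return (
            self._observations[index].astype(np.float32),
            self._actions[index].astype(np.float32),
            self._rewards[index].astype(np.float32),
            self._terminals[index].astype(np.float32),
            self._next_obs[index].astype(np.float32),
            {},  # env_infos left empty in this sketch
        )

The default DataLoader collate function stacks these numpy arrays into batched tensors, which matches how the trainer above unpacks data[0] through data[5].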

vitchyr commented 5 years ago

Hmmm, so it looks like the main difference is the addition of expert_data_collector. Is that correct? In that case, I'm not sure if we need to create an entirely new class for this. One option would be to add that data to the replay buffer before passing the replay buffer to the algorithm. What do you think of that? It would help separate out the algorithm from the pretraining phase.
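For illustration, that suggestion could look roughly like the sketch below, reusing the SAC setup from the earlier example; expert_policy is assumed to exist and the step counts are arbitrary.

from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

# Pre-fill the replay buffer with expert data before building the algorithm,
# so the training loop itself stays unchanged.
expert_collector = MdpPathCollector(expl_env, expert_policy)
expert_paths = expert_collector.collect_new_paths(
    max_path_length=1000,
    num_steps=50000,
    discard_incomplete_paths=False,
)
replay_buffer.add_paths(expert_paths)

# The stock batch algorithm can then be used as-is; it never needs to know
# where the initial data came from.
algorithm = TorchBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=expl_env,
    evaluation_env=eval_env,
    exploration_data_collector=MdpPathCollector(expl_env, policy),
    evaluation_data_collector=MdpPathCollector(eval_env, MakeDeterministic(policy)),
    replay_buffer=replay_buffer,
    batch_size=256,
    max_path_length=1000,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)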