matsuolab / BREMEN

Codebase of Deployment-Efficient Reinforcement Learning via Model-Based Offline Optimization (ICLR2021)
https://openreview.net/forum?id=3hGNqpI4WS

Run BREMEN on D4RL #5

Status: Open. IcarusWizard opened this issue 3 years ago.

IcarusWizard commented 3 years ago

Hi. Thanks for sharing the code. I am interested in offline reinforcement learning. In Appendix D of the paper, you show the performance of BREMEN on D4RL, but the launch script is not in the codebase. Do you plan to share the script for launching the D4RL experiments?

frt03 commented 3 years ago

@IcarusWizard For the D4RL experiments, you need to add the following method to the RolloutSampler class in libs/misc/data_handling/rollout_sampler.py:

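# At the top of libs/misc/data_handling/rollout_sampler.py:
import gym   # if not already imported
import d4rl  # importing d4rl registers its datasets as gym environments

# Inside class RolloutSampler:
    def generate_d4rl_data(self, dataset_name='hopper-medium-v0', n_train=int(1e6), horizon=1000):
        """Load a D4RL dataset and cut it into Path objects of `horizon` transitions."""
        print(dataset_name)
        # qlearning_dataset returns aligned arrays of (s, a, r, s') transitions.
        dataset = d4rl.qlearning_dataset(gym.make(dataset_name).env)
        s1 = dataset['observations']
        s2 = dataset['next_observations']
        a1 = dataset['actions']
        r = dataset['rewards']
        # The arrays should have equal length; min() guards against indexing past the shortest one.
        data_size = min(s1.shape[0], s2.shape[0], a1.shape[0], r.shape[0])
        n_train = min(n_train, data_size)
        paths = []
        # Consecutive fixed-length chunks; episode boundaries (dataset['terminals']) are not used.
        for i in range(n_train // horizon):
            path = Path()  # the repo's trajectory container, assumed importable in this file
            if i * horizon % 10000 == 0:
                print(i * horizon)  # progress
            for j in range(i * horizon, (i + 1) * horizon):
                obs = s1[j].tolist()
                action = a1[j].tolist()
                next_obs = s2[j].tolist()
                reward = r[j].tolist()
                path.add_timestep(obs, action, next_obs, reward)
            paths.append(path)

        return paths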
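The loader above cuts the flattened D4RL arrays into consecutive chunks of `horizon` transitions, so a chunk can span the end of one episode and the start of the next. A hedged alternative, sketched below under the assumption that the arrays come from `d4rl.qlearning_dataset` (which also returns a `terminals` array), is to split wherever a transition is terminal or wherever `next_observations[i]` does not match `observations[i + 1]`, which typically happens at the boundaries of episodes that ended by timeout. `episode_bounds` is a hypothetical helper, not part of BREMEN or D4RL.

```python
import numpy as np


def episode_bounds(observations, next_observations, terminals, atol=1e-6):
    """Return (start, end) index pairs, end exclusive, one per episode.

    Hypothetical helper: a transition closes an episode if it is terminal,
    or if the next stored observation does not continue from it.
    """
    n = len(terminals)
    if n == 0:
        return []
    breaks = np.asarray(terminals).astype(bool)  # astype returns a copy
    if n > 1:
        # Compare each transition's s' with the following transition's s.
        continuous = np.all(
            np.isclose(next_observations[:-1], observations[1:], atol=atol), axis=1
        )
        breaks[:-1] |= ~continuous
    ends = (np.flatnonzero(breaks) + 1).tolist()
    if not ends or ends[-1] != n:
        ends.append(n)  # last (possibly unfinished) episode runs to the end
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


obs = np.array([[0.0], [1.0], [2.0], [10.0], [11.0]])
next_obs = np.array([[1.0], [2.0], [3.0], [11.0], [12.0]])
print(episode_bounds(obs, next_obs, np.zeros(5)))  # [(0, 3), (3, 5)]
```

Each `(start, end)` pair could then replace `range(i * horizon, (i + 1) * horizon)` in the inner loop, producing one `Path` per episode. Both signals are used because, in `qlearning_dataset`, a terminal transition's stored next observation can coincide with the first observation of the following episode, so the continuity check alone would miss it.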

Then, in offline.py, replace the data-loading call in get_data_from_offline_batch as follows:

def get_data_from_offline_batch(params, env, normalization_scope=None, model='dynamics', split_ratio=0.9):
    train_collection = DataCollection(
        batch_size=params[model]['batch_size'],
        max_size=params['max_train_data'], shuffle=True)
    val_collection = DataCollection(batch_size=params[model]['batch_size'],
                                    max_size=params['max_val_data'],
                                    shuffle=False)
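    rollout_sampler = RolloutSampler(env)
    # Original loader for the repo's own offline data, disabled for D4RL:
    # rl_paths = rollout_sampler.generate_offline_data(
    #     data_file=params['data_file'],
    #     n_train=params['n_train']
    # )
    # For D4RL, params['data_file'] holds the dataset name (e.g. 'hopper-medium-v0').
    rl_paths = rollout_sampler.generate_d4rl_data(
        dataset_name=params['data_file'],
        n_train=params['n_train']
    )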
    path_collection = PathCollection()
    obs_dim = env.observation_space.shape[0]
    normalization = add_path_data_to_collection_and_update_normalization(
        rl_paths, path_collection,
        train_collection, val_collection,
        normalization=None,
        split_ratio=split_ratio,
        obs_dim=obs_dim,
        normalization_scope=normalization_scope)
    return train_collection, val_collection, normalization, path_collection, rollout_sampler

You also need to add a --data_file argument and comment out part of params_processing.py.
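A minimal sketch of the `--data_file` argument, assuming the parsed arguments end up in the `params` dict that `get_data_from_offline_batch` reads. How params_processing.py actually builds `params` is not shown in this thread, so the `vars(args)` step below is an illustration only.

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical wiring: the D4RL dataset name travels through --data_file
# and is later read as params['data_file'] in get_data_from_offline_batch.
parser.add_argument('--data_file', type=str, default='hopper-medium-v0',
                    help='D4RL dataset name, e.g. hopper-medium-v0')

# Parse an explicit list for demonstration; call parse_args() with no
# arguments to read the real command line instead.
args = parser.parse_args(['--data_file', 'walker2d-medium-v0'])
params = vars(args)
print(params['data_file'])  # walker2d-medium-v0
```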

Because the D4RL runs were an additional experiment, this code is quite rough. I hope it helps.

IcarusWizard commented 3 years ago

Hi @frt03, thanks for your help. I got it working.

A few more changes are needed. D4RL requires the latest versions of gym and mujoco_py, which are incompatible with the environments in this repo. For every environment defined in envs/gym, I had to rename _step to step and, in _get_obs, replace self.model with self.sim.
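As an alternative to renaming `_step` by hand in every environment, a small mixin could forward gym's newer `step` call to the existing `_step`. This is a hypothetical helper, not something from the repo, and it only covers the `step` rename; the `_get_obs` change still has to be made by hand (old mujoco_py code reads, e.g., `self.model.data.qpos`, while newer versions use `self.sim.data.qpos`). The stand-in base class below replaces the real gym/mujoco base so the sketch runs without either installed.

```python
class LegacyStepMixin:
    """Forward the newer gym API call `step` to a legacy `_step` method.

    Hypothetical helper: list it *before* the gym base class so that this
    `step` takes precedence over the base class's own `step`.
    """

    def step(self, action):
        return self._step(action)


class _Base:
    """Stand-in for a gym base class whose `step` is abstract."""

    def step(self, action):
        raise NotImplementedError


class OldStyleEnv(LegacyStepMixin, _Base):
    """Toy environment still written against the old `_step` API."""

    def _step(self, action):
        obs, reward, done, info = [action], float(action), False, {}
        return obs, reward, done, info


env = OldStyleEnv()
print(env.step(2))  # ([2], 2.0, False, {})
```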

IcarusWizard commented 3 years ago

I have an additional question about performance. I ran the code on halfcheetah-medium, hopper-medium and walker2d-medium with the hyper-parameters in the README, and got scores of 50.2, 35.7 and 13.4, respectively, at the last training iteration. These are quite different from the numbers reported in the paper, especially for the tasks with a termination function. Is something missing from my modifications, or did I use the wrong hyper-parameters or random seeds? What should I do to reproduce the results in the paper?
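If the figures above are D4RL normalized scores (the thread does not say so explicitly), comparing them with the paper requires the same normalization. D4RL defines it relative to a random-policy and an expert-policy reference return per task; its offline environments expose `env.get_normalized_score(raw_return)`, which returns this quantity without the factor of 100. The function below is a hypothetical stand-alone version, and the reference returns in the usage line are placeholders, not D4RL's values.

```python
def normalized_score(raw_return, random_return, expert_return):
    """D4RL-style normalized score: 0 matches a random policy, 100 an expert.

    random_return and expert_return are per-task reference returns.
    """
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


# Placeholder reference returns, for illustration only.
print(normalized_score(1500.0, 0.0, 3000.0))  # 50.0
```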

Moreover, I noticed that evaluation runs at each iteration for only 3000 steps, which may not be enough to evaluate performance reliably on hopper and walker2d.
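With a 3000-step budget and a 1000-step time limit, at most three full-length episodes are scored, so the mean return can be noisy. One option, sketched under assumptions: evaluate over a fixed number of complete episodes instead of a fixed step budget, so every episode runs to termination or the time limit. `evaluate_episodes` is hypothetical and assumes the older gym API that this repo's environments follow (`reset()` returns an observation, `step()` returns a 4-tuple).

```python
import numpy as np


def evaluate_episodes(env, policy, n_episodes=10, max_steps=1000):
    """Run n_episodes complete episodes; return (mean, std) of their returns.

    Hypothetical helper assuming the old gym API: reset() -> obs and
    step(a) -> (obs, reward, done, info).
    """
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        total = 0.0
        for _ in range(max_steps):  # max_steps acts as the time limit
            obs, reward, done, _ = env.step(policy(obs))
            total += reward
            if done:  # early termination, e.g. hopper or walker2d falling
                break
        returns.append(total)
    return float(np.mean(returns)), float(np.std(returns))
```

For example, `evaluate_episodes(env, policy, n_episodes=10)` scores ten full episodes regardless of how early each one terminates.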