srzer / LaMo-2023

Official code for "Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning".
https://lamo2023.github.io
MIT License

More details about reproducing the experiment #1

Closed aopolin-lv closed 7 months ago

aopolin-lv commented 8 months ago

Hi, authors, congratulations on LaMo being accepted by ICLR 2024. It is amazing that this work injects pre-trained knowledge into the original decision transformer and uses LoRA to fine-tune the network efficiently. I ran into some questions when reproducing the experimental results. Could you provide more details?

The questions are as follows:

  1. The hyperparameter settings in the given script run.sh are not consistent with the noted default ones. Could you provide more details about the hyperparameters for the different datasets, such as a table of the hyperparameters (lr, lmlr, weight_decay, warmup_steps, num_steps_per_iter, max_iters, num_eval_episodes) used in the experiments?
  2. How are the scores reported in the paper computed? It is mentioned that "To systematically report the performance of all these methods, we compute the average performance over the last 20K training steps out of a total of 100K training steps with evaluations conducted every 2500 training steps."
  3. The official decision transformer uses the environments Hopper-v3, HalfCheetah-v3, and Walker2d-v3, while this work uses hopper-medium-v2, halfcheetah-medium-v2, and walker2d-medium-v2. What is the difference between these two settings?

More interestingly, I found that the count of trainable parameters does not include the prediction layers. If these parameters are included, the ratio of trainable parameters exceeds the 0.7% reported in the paper.

srzer commented 8 months ago

Thank you for your issue! The questions are answered as follows:

  1. The script run.sh is an example script; we provide hyperparameter tables in the appendix (Hyperparameters, Tables 6 and 7), and you can refer to our camera-ready version for them.
  2. Yes, you can follow the instruction you quoted. We will consider providing a more detailed description of this or releasing the data processing scripts.
  3. https://www.gymlibrary.dev/environments/mujoco/half_cheetah/#version-history shows the difference; it has no influence on the evaluation. Additionally, DT uses v2 data, as shown in https://github.com/kzl/decision-transformer/blob/master/gym/data/download_d4rl_datasets.py#L14, so using the v2 environments for evaluation is common practice.
  4. Yes, it is true that the prediction layers are not counted. We report $\frac{\text{trainable parameters in Transformer}}{\text{parameters in Transformer}}$, which is 0.728017% for the D4RL tasks, as the new predictor is not inherited from the LM (see the sketch below). We state this point in Section 4.2: "Meanwhile, the model is desired to maintain the knowledge of the LMs. The number of trainable parameters only takes up 0.7% of the entire Transformer."
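
As an illustration, here is a minimal sketch of how that ratio is computed (assuming a standard PyTorch model whose GPT-2 backbone is exposed as a `transformer` submodule; this is not the repo's exact code): only parameters inside the Transformer are counted, and the newly initialized prediction heads are excluded.

```python
import torch.nn as nn

def trainable_ratio(transformer: nn.Module) -> float:
    """Ratio of trainable parameters to all parameters inside the Transformer backbone,
    excluding the newly initialized prediction heads, which live outside this module."""
    trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    total = sum(p.numel() for p in transformer.parameters())
    return trainable / total

# Hypothetical usage: with the LoRA adapters as the only trainable weights,
# print(f"{100 * trainable_ratio(model.transformer):.6f}%")  # ~0.728017% on D4RL tasks
```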

You are also welcome to contact me (srz21@mails.tsinghua.edu.cn) with more detailed questions! We will try our best to help you reproduce the experiments.

aopolin-lv commented 8 months ago

Thanks for your response. By the way, I have a few more questions:

  1. How do you set num_eval_episodes, 20 or 100?
  2. Is co_lambda always set to 0.1?
  3. What is the meaning of the Return-to-go hyperparameter in Table 6? How can I modify this setting in the code?

Many thanks for your patient response.

srzer commented 8 months ago
  1. Setting num_eval_episodes to 20 is to save evaluation time :)
  2. In most cases, it is. For each task, we report the best result with co_lambda picked from {0, 0.1, 1}. For MuJoCo and Atari tasks, setting it to 0 is good enough.
  3. The Return-to-go provided is a list, and we test all of its values during evaluation. You can set it in https://github.com/srzer/LaMo-2023/blob/main/experiment-d4rl/experiment.py#L57 (see the illustrative excerpt below).
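
For reference, here is an illustrative excerpt of what that list looks like. The environment names and values are hypothetical examples following the pattern of the DT codebase, not necessarily the exact values shipped in this repo:

```python
# Hypothetical excerpt mirroring the env_targets pattern in experiment.py:
# each value is a separate return-to-go condition evaluated on its own.
if env_name == "hopper":
    env_targets = [3600, 1800]
elif env_name == "halfcheetah":
    env_targets = [12000, 6000]
elif env_name == "walker2d":
    env_targets = [5000, 2500]
```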
aopolin-lv commented 8 months ago
  1. The score reported in the paper is under the setting num_eval_episodes=100, is that right? In addition, the evaluation takes quite a long time; is there any way to speed it up?
  2. In my opinion, if co_lambda is set to 0, the gradient coming from the next-token prediction task on natural language will not update the weights of the network, and thus this setting would have no effect. I don't know whether this understanding is correct. Could you explain more about why co_lambda is set to 0 and how it takes effect?
  3. Got it, absolutely.

    Thanks again.

srzer commented 8 months ago

Thank you for your issue!

  1. We adopt num_eval_episodes=20 in experiments for efficiency, and we have validated on representative tasks that this change has little influence on the final results. Unfortunately, at the moment, I don't know how to speed up the evaluation.
  2. Yes, your understanding is correct. The auxiliary loss is intended to stabilize the training process (see the schematic sketch below). For Atari and D4RL tasks except Kitchen (which is difficult), our framework without the auxiliary loss is empirically powerful enough to obtain high scores and prevent overfitting, so we can set co_lambda to 0. We state this point in the limitations part of our paper.
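
Schematically, the objective can be sketched as below (a simplified illustration, not the repo's exact implementation). With co_lambda = 0, the auxiliary next-token language-modeling term contributes no gradient, and only the action-prediction loss updates the weights:

```python
import torch

def total_loss(action_prediction_loss: torch.Tensor,
               language_modeling_loss: torch.Tensor,
               co_lambda: float) -> torch.Tensor:
    # With co_lambda = 0, the LM term vanishes and training reduces to
    # the plain offline-RL (action-prediction) objective.
    return action_prediction_loss + co_lambda * language_modeling_loss
```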
aopolin-lv commented 8 months ago

Thank you. I have tried to reproduce the experimental results. However, the result on hopper-medium under the setting sample_ratio=1 is far from that reported in the paper, and it seems that some other results are also not consistent with those in the paper.

Could you give me some advice? Or could you provide your .ckpt files?

| env | dataset | ratio | LaMo | LaMo w/o cotrain | Reimplement LaMo | Reimplement LaMo w/o cotrain |
| --- | --- | --- | --- | --- | --- | --- |
| **Dense Reward** | | | | | | |
| hopper | medium | 1 | 74.1 | 60.9 | 46.04 | 46.73 |
| halfcheetah | medium | 1 | 42.5 | 42.6 | 42.5 | 42.4 |
| walker2d | medium | 1 | 73.3 | 70.2 | 70.36 | 72.4 |
| **Sparse Reward** | | | | | | |
| kitchen | partial | 1 | 46.6 | 33.8 | 30.34 | 42.19 |
| kitchen | complete | 1 | 64.2 | 52.8 | 60.56 | 59.78 |
| reacher | medium | 1 | 33 | 22.8 | 26.62 | 27.51 |
srzer commented 8 months ago

Happy to see you running the experiments; your efforts help us make this project easier to use!

Firstly, we kindly point out a mistake in your table: the name of the 5th column should be "DT" instead of "LaMo w/o cotrain" (LaMo w/o cotrain isn't simply DT; it involves LoRA and MLP embeddings as well). So the results of the 5th column shouldn't align with the results of the 7th column.

Hopper is a special task: its state representation is so simple that the linear embedding already works, and the more complicated embedding generated by the MLP would hurt the performance, as we state in Appendix: IMPLEMENTATION DETAILS, "Network architecture for LaMo". We assure you that this issue does not appear in other tasks. As for Kitchen partial (100%), as we state in Figure 10 of our paper, GPT2-medium slightly surpasses GPT2-small, and we report the results using GPT2-medium.
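
Schematically, the two embedding choices can be sketched as follows (module names and structure are hypothetical, not the repo's exact classes):

```python
import torch.nn as nn

def make_state_embedding(state_dim: int, hidden_size: int, use_mlp: bool) -> nn.Module:
    # --mlp_embedding replaces the plain linear projection with a small MLP;
    # for Hopper, the plain linear embedding works better.
    if use_mlp:
        return nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )
    return nn.Linear(state_dim, hidden_size)
```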

The inconsistency can be partially attributed to a slight difference in the way the final score is calculated, or to some randomness that is not controlled by the seed during evaluation.

For simplicity and efficiency, here we provide the hyperparameters we used for those tasks whose reimplemented results show a notable gap compared with our reported values.

The specific hyperparameters we adopt for hopper-medium (100%) are:

--K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-5 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode 1  --adapt_embed 1 --lora 1 --mlp_embedding 0 --dropout 0.1

The specific hyperparameters we adopt for reacher2d-medium (100%) are:

--K 5 -lr 1e-5 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-4 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode 1 --adapt_embed 1 --lora 1 --mlp_embedding 1 --dropout 0.1

The specific hyperparameters we adopt for kitchen-partial (100%) are:

--K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-5 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2-medium --adapt_mode 1 --adapt_embed 1 --lora 1 --mlp_embedding 1 --dropout 0.1 --co_training --co_lambda 0.1

The specific hyperparameters we adopt for kitchen-complete (100%) are:

--K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-4 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode 1 --adapt_embed 1 --lora 1 --mlp_embedding 1 --dropout 0.1
aopolin-lv commented 8 months ago

Thank you for your patient response. It helps me a lot.

Firstly, you are right that LaMo w/o cotrain is not equal to DT; thank you for pointing it out. In addition, regarding the experimental settings, I find that the weight_decay for reacher2d-medium and kitchen-partial in the paper is $1\times10^{-5}$, which is not consistent with your suggestion of $1\times10^{-4}$. Meanwhile, the warmup_steps I set previously was 10000, the default value in the script, whereas in the hyperparameters you provided it is supposed to be 2500.

I will try again, and thanks again for sharing this important experience with me.

aopolin-lv commented 8 months ago

Hi, I have strictly followed the setting you provided for hopper-medium (100%) with seeds {0, 1, 2}. However, the average result is 53.5, which is still far from the 74.1 in your paper. Could you release the model weights?

--K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-5 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode 1 --adapt_embed 1 --lora 1 --mlp_embedding 0 --dropout 0.1

By the way, env_targets in the code contains more values than the paper lists. I only compute the result with the return-to-go conditioning written in the paper; is that right? And how does the decision transformer evaluate?

srzer commented 8 months ago

Yes, we can provide the weights; please wait for some time. Before that, could you provide a figure of the learning curves if you used wandb? That would let us know more details about your results. As for your way of computing the result, that's alright. Sorry, I can't exactly understand the question "And how does the decision transformer evaluate?" — could you rephrase it? Thank you!

srzer commented 8 months ago

We calculate the final score for each return-to-go separately and report the one with the highest score (for Hopper-medium (100%), we therefore only report the final score with return-to-go = 3600). We didn't state this point explicitly, as it is common practice in return-conditioned supervised learning. Did you follow this procedure? (A sketch of the computation is given below.)
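
The sketch below illustrates one way to read this protocol (it is not the actual evaluation script): average each return-to-go's curve over the evaluations falling in the last 20K of the 100K training steps (8 evaluations at a 2500-step interval), then report the best-performing return-to-go.

```python
import numpy as np

def final_score(scores_by_rtg: dict, eval_interval: int = 2500, last_steps: int = 20_000):
    """scores_by_rtg maps each return-to-go value to its list of normalized scores,
    one per evaluation, ordered by training step."""
    n_last = last_steps // eval_interval  # 8 evaluations over the last 20K steps
    averaged = {rtg: float(np.mean(curve[-n_last:])) for rtg, curve in scores_by_rtg.items()}
    best_rtg = max(averaged, key=averaged.get)
    return best_rtg, averaged[best_rtg]
```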

srzer commented 8 months ago

There is also a concern: the script has been aggressively cleaned for release, so the right way to remove the MLP embedding is to not pass --mlp_embedding at all, instead of passing --mlp_embedding 0. I stated the latter when providing the hyperparameters due to my carelessness; sorry if you were misled.

aopolin-lv commented 8 months ago

[Learning-curve images: reproduce_hopper_medium_seed0, reproduce_hopper_medium_seed1, reproduce_hopper_medium_seed2]

The pictures above are the learning curves under seeds {0, 1, 2}, respectively.

Are the decision transformer results reported in your paper evaluated in the same way mentioned before ("To systematically report the performance of all these methods, we compute the average performance over the last 20K training steps out of a total of 100K training steps with evaluations conducted every 2500 training steps")?

And do you know which evaluation method the decision transformer used in its original paper?

In addition, thank you for pointing out the usage of the --mlp_embedding flag; that is in fact how I ran my experiments before you mentioned it.

srzer commented 8 months ago

Thank you for providing this. Yes, these curves are reasonable if --mlp_embedding is applied, and removing it should help increase the performance on the Hopper task; sorry for the confusion again. Yes, we report all baselines in the same way. I don't know the exact evaluation method used in the decision transformer; to the best of my knowledge, they didn't state it explicitly in the original paper. We set our metric this way because, in the context of offline RL, the agent can never interact with the environment during training, and thus mitigating overfitting is a crucial aspect of a method's effectiveness. This choice aligns with our perspective on the importance of robustness in the evaluation of offline RL methods.

aopolin-lv commented 8 months ago

> Yes, these curves are reasonable if --mlp_embedding is applied, and removing it should help increase the performance on the Hopper task.

Sorry for the confusion: these curves are the results without passing --mlp_embedding.

The updated experimental results are as follows. They are still a little different from those in the paper.

| env | dataset | ratio | LaMo | DT | Reimplement LaMo (before) | Reimplement LaMo (now) |
| --- | --- | --- | --- | --- | --- | --- |
| **Dense Reward** | | | | | | |
| hopper | medium | 1 | 74.1 | 60.9 | 46.04 | 60.92 |
| halfcheetah | medium | 1 | 42.5 | 42.6 | 42.5 | (in training) |
| walker2d | medium | 1 | 73.3 | 70.2 | 70.36 | 68.26 |
| **Sparse Reward** | | | | | | |
| kitchen | partial | 1 | 46.6 | 33.8 | 30.34 | 44.84 |
| kitchen | complete | 1 | 64.2 | 52.8 | 60.56 | 64.84 |
| reacher | medium | 1 | 33 | 22.8 | 26.62 | 31.15 |
srzer commented 8 months ago

I have re-conducted the experiment on hopper-medium (100%), and the resulting curve [learning-curve image] aligns with 74.1. Could you show me your script?

aopolin-lv commented 8 months ago

Could you please provide the seed?

My execution command is as follows: --K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-5 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode --adapt_embed --lora --dropout 0.1

srzer commented 8 months ago

The seed corresponding to that curve is 0. Your command looks good, and the only difference seems to be that you run experiment.py directly with an execution command instead of using a script. If so, what about reproducing with the same script I used? I've attached it here. Please try bash run_hopper.sh hopper medium 1 reproduce [the seed number] [your_gpu_id]. I believe it will work.

aopolin-lv commented 8 months ago

Hello, I have tried it but failed. Which type of GPU do you use?

By the way, could you please give me some suggestions about such following questions:

  1. I use the reacher2d data you provide in this repo. How can I download this data other than from this repo (it is not mentioned in the decision transformer repo or code)?
  2. I use env.render to successfully visualize the MuJoCo scenes, except for kitchen. How can I visualize it?
srzer commented 8 months ago

We use an A40 GPU.

  1. The decision transformer paper mentions reacher, but they don't provide the data in their repo. Our data is generated by ourselves using a PPO agent (a rough sketch of such a collection loop is given below).
  2. Yes, kitchen cannot directly use .render() for visualization, as this particular offline environment disables it. To visualize it in d4rl, the render() calls in the environment package would need to be modified, and we haven't looked into this in detail.
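
A rough sketch of such a collection loop (illustrative only, assuming the classic gym API; this is not our actual generation script, and `policy` stands for any trained agent such as PPO):

```python
import gym
import numpy as np

def collect_dataset(env_name: str, policy, n_episodes: int):
    """Roll out a trained policy and store trajectories in a D4RL-style format."""
    env = gym.make(env_name)
    paths = []
    for _ in range(n_episodes):
        obs, done = env.reset(), False
        traj = {"observations": [], "actions": [], "rewards": [], "terminals": []}
        while not done:
            action = policy(obs)  # policy maps an observation to an action
            next_obs, reward, done, _ = env.step(action)
            traj["observations"].append(obs)
            traj["actions"].append(action)
            traj["rewards"].append(reward)
            traj["terminals"].append(done)
            obs = next_obs
        paths.append({k: np.array(v) for k, v in traj.items()})
    return paths
```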
Victor-wang-902 commented 7 months ago

Congratulations on your ICLR acceptance, and I appreciate your detailed explanation! It looks like, for the return-to-go conditioning, you have made some slight modifications compared to the original DT paper. Can you tell me the exact rtg value (with the highest final score) for each environment, since in the paper there are two values instead of one for each experiment?

srzer commented 7 months ago

We follow the common practice of DT when using multiple values as rtg, as shown in https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L41. The values we adopt are validated by experiments; one reason is that for different sampling ratios the best rtg can differ. In the experiments we tried more than two rtg values, and sometimes different rtgs perform similarly. Here are the rtgs with the highest score in our experiments:

| Env | Sampling ratio | Best rtg |
| --- | --- | --- |
| Hopper | 1 | 3600 |
| Hopper | 0.1 | 3600 |
| Hopper | 0.01 | 2200 |
| Hopper | 0.005 | 3600 |
| Walker2d | 1 | 5000 |
| Walker2d | 0.1 | 4000 |
| Walker2d | 0.01 | 2500 |
| Walker2d | 0.005 | 2500 |
| HalfCheetah | 1 | 6000 |
| HalfCheetah | 0.1 | 6000 |
| HalfCheetah | 0.01 | 6000 |
| HalfCheetah | 0.005 | 6000 |
| Kitchen Complete | 1 | 3 |
| Kitchen Complete | 0.5 | 4 |
| Kitchen Complete | 0.3 | 4 |
| Kitchen Partial | 1 | 3 |
| Kitchen Partial | 0.1 | 3 |
| Kitchen Partial | 0.01 | 3 |
| Reacher2d | 1 | 40 |
| Reacher2d | 0.3 | 40 |
| Reacher2d | 0.1 | 40 |
Victor-wang-902 commented 7 months ago

> We follow the common practice of DT when using multiple values as rtg, as shown in https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py#L41. [...] Here are the rtgs with the highest score in our experiments: (best-rtg table quoted above)

Thank you so much. In the original DT paper, however, only one reward conditioning was reported despite the code having multiple targets; I think they used whatever single value was in the paper for the final results. Therefore, it would be great if you could also mention the validation process in the hyperparameter section of the paper.

Anyways, this is great information, thanks a lot!

srzer commented 7 months ago

Thank you for pointing it out; we will emphasize this in our camera-ready version!

egg-west commented 7 months ago

It seems the reproduced results are not yet close to the reported results, so I reopened this issue to provide some results. I am using the hyperparameters quoted below, and I am running run_hopper.sh.

The results I got on hopper-medium-v2 with 100% data are aligned with the reproduced results from @aopolin-lv. The seed I used is 0, and the maximum averaged return is 68.9 (reported: 74.1 ± 5.3).

[learning-curve image]

> Could you please provide the seed?
>
> My execution command is as follows: --K 20 -lr 1e-4 -lmlr 1e-5 --num_steps_per_iter 2500 --weight_decay 1e-5 --max_iters 40 --num_eval_episodes 20 --sample_ratio 1 --warmup_steps 2500 --pretrained_lm gpt2 --adapt_mode --adapt_embed --lora --dropout 0.1

I also tried 10% data (seed 0), and the results are below; the maximum return is 64.1 (reported: 73.7 ± 3.5).

[learning-curve image]
srzer commented 7 months ago

Thank you for reopening this issue. I will re-run the experiments from scratch to figure out whether there is any misalignment, and I will respond to you ASAP!

srzer commented 7 months ago

Hi! We have figured out that this issue is due to differences in the versions of certain packages. We will update the README in 2 days. Thank you for pointing this out! @egg-west @aopolin-lv

srzer commented 7 months ago

README updated. You are welcome to configure the environment following the README and try again.

aopolin-lv commented 7 months ago

Thank you for the update. I will try it later.

egg-west commented 7 months ago

Thank you @srzer for solving this issue. My latest results with the updated environment configuration reproduce the reported results for the Hopper medium tasks.

[learning-curve image]

where the blue line is the previous reproduction and the purple line indicates the latest run. The sample ratio is 1.0 and the seed is set to 0.