allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0

Error with Accelerate integration + NLPO #52

Open avacaondata opened 1 year ago

avacaondata commented 1 year ago

Hi, I'm trying to use the Accelerate integration because, without it, I cannot run NLPO on even a small model (200M parameters) with a sequence length of 512 tokens, not even on an 80GB A100. That makes NLPO impractical for almost any problem unless you can use Accelerate, DeepSpeed, or another integration that splits the model across GPUs and CPUs. However, when I try to do so, I get the following error:

Traceback (most recent call last):
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/scripts/training/train_text_generation.py", line 95, in <module>
    main(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/scripts/training/train_text_generation.py", line 56, in main
    trainer = OnPolicyTrainer(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 190, in __init__
    self._setup()
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 207, in _setup
    self._alg = build_alg(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 151, in build_alg
    alg = wrapper(
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 451, in wrap_onpolicy_alg
    alg = OnPolicyAlgText(alg_kwargs, kl_coeff, tracker, accelerator, target_kl, norm_reward)
  File "/home/alejandro.vaca/nlp_rl/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 126, in __init__
    super().__init__(**alg_kwargs)
TypeError: __init__() got an unexpected keyword argument 'accelerator'
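Reading the traceback, the failure appears to happen when `OnPolicyAlgText.__init__` calls `super().__init__(**alg_kwargs)`. At that point `alg_kwargs` seems to contain an `accelerator` entry, but the NLPO base class's `__init__` does not declare that parameter. The sketch below reproduces this mechanism in isolation. All class and helper names (`HypotheticalNLPO`, `HypotheticalOnPolicyAlgText`, `drop_unsupported_kwargs`) and the kwarg values are hypothetical stand-ins, not RL4LMs code. The filtering helper is only a stopgap that avoids the crash by discarding `accelerator`, so it does not make NLPO use Accelerate.

```python
import inspect
from typing import Any, Dict


class HypotheticalNLPO:
    """Hypothetical stand-in for an algorithm whose __init__ has no `accelerator` parameter."""

    def __init__(self, policy: str, n_steps: int = 128) -> None:
        self.policy = policy
        self.n_steps = n_steps


class HypotheticalOnPolicyAlgText(HypotheticalNLPO):
    """Hypothetical stand-in for a wrapper that forwards every kwarg to the base class."""

    def __init__(self, alg_kwargs: Dict[str, Any]) -> None:
        # Mirrors the failing pattern: unpack everything into the parent constructor.
        super().__init__(**alg_kwargs)


def drop_unsupported_kwargs(cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that cls.__init__ explicitly declares by name."""
    params = inspect.signature(cls.__init__).parameters
    return {k: v for k, v in kwargs.items() if k in params}


# Hypothetical kwargs: `accelerator` is present but unsupported by the base class.
alg_kwargs = {"policy": "seq2seq_lm_policy", "n_steps": 256, "accelerator": object()}

message = ""
try:
    HypotheticalOnPolicyAlgText(alg_kwargs)
except TypeError as err:
    # Same class of error as in the traceback: unexpected keyword argument 'accelerator'.
    message = str(err)
print(message)

# Stopgap: strip kwargs the base class cannot accept (Accelerate is simply not used).
filtered = drop_unsupported_kwargs(HypotheticalNLPO, alg_kwargs)
alg = HypotheticalOnPolicyAlgText(filtered)
print(sorted(filtered))
```

A real fix would add `accelerator` support to the NLPO class itself, which is what the maintainers describe as not yet migrated.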

Also, I'm curious how you were able to run the benchmarks in your paper, since I don't know of any single GPU with more than 80GB of memory, and even on one of those I cannot run NLPO with sequences longer than 128 tokens. How did you do it? Is there something I'm missing? @rajcscw @jmhessel @rajammanabrolu @JulesGM @akifumi-wachi-4

Thank you very much for this amazing work!! :)

rajcscw commented 1 year ago

@avacaondata Thanks for reporting. We have not fully migrated NLPO to use accelerate. We will fix this soon and let you know.