allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0

'GPT2Model' object has no attribute 'first_device' #56

Open · Stephanehk opened this issue 1 year ago

Stephanehk commented 1 year ago

I get the following error when running:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/dialog/gpt2_ppo.yml

I have double-checked that the installed version is transformers==4.18.0.

Traceback (most recent call last):
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 84, in <module>
    main(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 55, in main
    trainer.train_and_eval()
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 232, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 342, in learn
    return super().learn(
  File "/opt/anaconda3/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 384, in collect_rollouts
    rollout_info = self.generate_batch(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 159, in generate_batch
    gen_output = self.policy.generate(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py", line 230, in generate
    inputs=input_ids.to(self.get_policy_first_device()),
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/causal_policy.py", line 259, in get_policy_first_device
    self._policy_model.transformer.first_device
  File "/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2Model' object has no attribute 'first_device'
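One plausible reading of the traceback: in transformers 4.18, `GPT2Model` only gets a `first_device` attribute inside `parallelize()` (its model-parallel mode), so `get_policy_first_device` raises whenever the policy's transformer was never parallelized. The sketch below is a hypothetical workaround, not RL4LMs code. `resolve_first_device` is an illustrative name, and the helper assumes the policy model exposes `.transformer` and `.parameters()` the way `GPT2LMHeadModel` does. It uses `first_device` when it exists and otherwise falls back to the device of the model's first parameter.

```python
from typing import Any


def resolve_first_device(policy_model: Any, default: str = "cpu") -> Any:
    """Return the device that generation inputs should be moved to.

    Hypothetical helper (not part of RL4LMs or transformers). It prefers
    `policy_model.transformer.first_device`, which GPT-2 only sets inside
    `parallelize()`, and otherwise falls back to the device of the first
    model parameter, or to `default` if the model has no parameters.
    """
    transformer = getattr(policy_model, "transformer", None)
    # torch.nn.Module.__getattr__ raises AttributeError for unknown names,
    # so getattr() with a default returns None instead of crashing.
    first_device = getattr(transformer, "first_device", None)
    if first_device is not None:
        return first_device

    # Not parallelized: assume all weights share one device and use the first.
    params = getattr(policy_model, "parameters", None)
    if callable(params):
        first_param = next(iter(params()), None)
        if first_param is not None:
            return first_param.device
    return default
```

With a real PyTorch model the fallback returns a `torch.device`. For example, the call at base_policy.py line 230 could become `input_ids.to(resolve_first_device(self._policy_model))`. The fallback only makes sense when the model is not split across devices. If model parallelism was intended, the underlying fix is to make sure `parallelize()` actually runs.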