huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DeepSpeed ZeRO-3 throws `RuntimeError: 'weight' must be 2-D` for `sentiment_tuning.py` #600

Closed: lewtun closed this issue 1 year ago

lewtun commented 1 year ago

I'm trying to run the sentiment_tuning.py example with accelerate and DeepSpeed ZeRO-3, but I'm hitting a runtime error about tensor shapes when computing the reference model's log probs:

Traceback (most recent call last):
  File "/fsx/lewis/git/trl/scratch/sentiment_tuning.py", line 207, in <module>
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/lewis/git/trl/trl/trainer/ppo_trainer.py", line 660, in step
    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/lewis/git/trl/trl/trainer/ppo_trainer.py", line 916, in batched_forward_pass
    logits, _, values = model(**input_kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/git/trl/trl/models/modeling_value_head.py", line 165, in forward
    base_model_output = self.pretrained_model(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1076, in forward
    transformer_outputs = self.transformer(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 843, in forward
    inputs_embeds = self.wte(input_ids)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D

Note that I've added the correct device placement for the reward model here, so as far as I can tell that is not the cause. There is also no error with ZeRO-2, which suggests that ZeRO-3's sharding of the model weights is the problem.
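To make the ZeRO-2 vs ZeRO-3 reasoning concrete, here is a toy, pure-Python model of parameter sharding. It is a sketch under stated assumptions, not DeepSpeed's code: `ShardedParam`, `partition`, `gather`, `release` and `forward` are hypothetical names. The assumption it illustrates is that ZeRO-1/2 shard only optimizer states and gradients, while ZeRO-3 also shards the parameters, so each rank keeps only a flat slice of every weight and the local `weight` is an empty placeholder until something gathers the full tensor before the forward pass (as far as I know, DeepSpeed does this with hooks on modules wrapped by its engine, and the placeholder is a 0-element 1-D tensor). If nothing gathers the embedding weight, the lookup sees a non-2-D weight and fails with the same message as in the traceback.

```python
"""Toy model of ZeRO-3 parameter sharding (hypothetical; not DeepSpeed's code)."""
from dataclasses import dataclass, field
from typing import List, Tuple


def embedding(weight, input_ids):
    """Row lookup with a rough imitation of torch.embedding's 2-D weight check."""
    if not weight or not isinstance(weight[0], list):
        raise RuntimeError("'weight' must be 2-D")
    return [weight[i] for i in input_ids]


@dataclass
class ShardedParam:
    """Hypothetical stand-in for a ZeRO-3 partitioned parameter."""
    full_shape: Tuple[int, int]  # original (vocab_size, hidden_dim)
    shards: List[List[float]]    # one flat, equally sized shard per rank
    data: list = field(default_factory=list)  # empty placeholder while partitioned

    @classmethod
    def partition(cls, rows, world_size):
        flat = [x for row in rows for x in row]
        shard_size = -(-len(flat) // world_size)  # ceiling division
        flat += [0.0] * (shard_size * world_size - len(flat))  # pad the last shard
        shards = [flat[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]
        return cls((len(rows), len(rows[0])), shards)

    def gather(self):
        """Concatenate every rank's shard (stand-in for an all-gather) and restore 2-D."""
        flat = [x for shard in self.shards for x in shard]
        n_rows, n_cols = self.full_shape
        self.data = [flat[i * n_cols:(i + 1) * n_cols] for i in range(n_rows)]

    def release(self):
        """Drop the gathered copy again, keeping only the shards."""
        self.data = []


def forward(param, input_ids, gather_first):
    # gather_first=True plays the role of a pre-forward gather hook;
    # False models a partitioned module whose weights nobody gathers.
    if gather_first:
        param.gather()
    try:
        return embedding(param.data, input_ids)
    finally:
        if gather_first:
            param.release()


vocab = [[float(i), float(i) + 0.5] for i in range(5)]  # 5 tokens, hidden dim 2
wte = ShardedParam.partition(vocab, world_size=2)
print(forward(wte, [1, 3], gather_first=True))  # [[1.0, 1.5], [3.0, 3.5]]
try:
    forward(wte, [1, 3], gather_first=False)
except RuntimeError as err:
    print(err)  # 'weight' must be 2-D
```

If this hypothesis holds, the thing to check is whether the reference model used in `batched_forward_pass` has partitioned parameters that are never gathered before its forward pass; that is an assumption to verify, not a confirmed diagnosis.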

Steps to reproduce

  1. Create a ZeRO-3 config with `accelerate config`:
# config_ds3.yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  2. Download either my sentiment_tuning.py Gist (link) or use the official example, then run it with:
accelerate launch --config_file config_ds3.yaml sentiment_tuning.py --log_with="wandb"

Expected behaviour

sentiment_tuning.py runs with ZeRO-3 without errors.

Env

- `transformers` version: 4.31.0
- Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
- Python version: 3.10.10
- Huggingface_hub version: 0.16.4
- Safetensors version: 0.3.1
- Accelerate version: 0.21.0
- DeepSpeed version: 0.9.5
- Accelerate config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: MULTI_GPU
        - mixed_precision: bf16
        - use_cpu: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: all
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
- PyTorch version (GPU?): 2.0.0+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: distributed

cc @pacman100, who might have seen a similar issue in other contexts

YooSungHyun commented 1 year ago

Me too. I get the same error even when using only optimizer offload.

YooSungHyun commented 1 year ago

https://github.com/huggingface/trl/issues/669