I'm trying to run the `sentiment_tuning.py` example with `accelerate` and DeepSpeed ZeRO-3, but I'm hitting a runtime error with the shapes of the tensors when computing the log probs:
```
Traceback (most recent call last):
  File "/fsx/lewis/git/trl/scratch/sentiment_tuning.py", line 207, in <module>
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/lewis/git/trl/trl/trainer/ppo_trainer.py", line 660, in step
    ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/lewis/git/trl/trl/trainer/ppo_trainer.py", line 916, in batched_forward_pass
    logits, _, values = model(**input_kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/git/trl/trl/models/modeling_value_head.py", line 165, in forward
    base_model_output = self.pretrained_model(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1076, in forward
    transformer_outputs = self.transformer(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 843, in forward
    inputs_embeds = self.wte(input_ids)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
```
Note that I've added the correct device placement for the reward model here, and that looks to be unrelated as far as I can tell. Note also that there is no error with ZeRO-2, which suggests weight sharding is the problem.
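For what it's worth, the error itself is easy to reproduce in plain PyTorch: under ZeRO-3 each rank holds only a flattened shard of every parameter, so if the reference model's embedding weight isn't gathered before the forward pass, `F.embedding` sees a 1-D tensor instead of the expected `(vocab_size, hidden_dim)` matrix. This is only a rough sketch of the failure mode, not DeepSpeed's actual sharding mechanics:

```python
import torch
import torch.nn.functional as F

weight_2d = torch.randn(10, 4)     # a normal (vocab_size, hidden_dim) embedding table
input_ids = torch.tensor([1, 2, 3])

out = F.embedding(input_ids, weight_2d)   # works: shape (3, 4)

# A flattened tensor stands in for what a ZeRO-3 parameter shard looks like
# to the forward pass when it hasn't been gathered.
flat_shard = weight_2d.flatten()
try:
    F.embedding(input_ids, flat_shard)
except RuntimeError as e:
    print(e)  # same error as in the traceback: 'weight' must be 2-D
```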
## Steps to reproduce

1. Create an `accelerate config` for ZeRO-3.
2. Grab `sentiment_tuning.py` from the Gist (link) or use the official example. Then run with:

## Expected behaviour

I can run `sentiment_tuning.py` with ZeRO-3 and no error.

## Env
cc @pacman100 who might have seen a similar issue in other contexts