BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

GPT-2 broken starting in v1.2.5 #99

Closed eric-mitchell closed 1 year ago

eric-mitchell commented 1 year ago

Thanks for the cool package. As of version 1.2.5, I can no longer run a forward pass on GPT-2 (gpt2-xl in this case). A simple repro script:

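import transformers
import tensor_parallel as tp

# Load GPT-2 XL and its tokenizer (cache_dir is specific to the reporter's machine).
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2-xl', cache_dir='/scr-ssd/em7')
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2-xl', cache_dir='/scr-ssd/em7')
# Wrap the model for tensor parallelism (no explicit device list is passed).
model = tp.tensor_parallel(model)

test_input = 'I enjoy walking with my cute dog'
tokens = tokenizer(test_input, return_tensors='pt').to('cuda:0')
# Reuse the input IDs as labels so the forward pass also computes a loss.
tokens['labels'] = tokens['input_ids'].clone()
outputs = model(**tokens)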

The output on 1.2.5 is:

The following patterns in state_rules were unused: ["re.compile('.*lm_head\\\\.weight$')", "re.compile('.*q_attn\\\\.weight$')", "re.compile('.*q_attn\\\\.bias$')"]
The following patterns in state_rules were unused: ["re.compile('.*lm_head\\\\.weight$')", "re.compile('.*q_attn\\\\.weight$')", "re.compile('.*q_attn\\\\.bias$')"]
Using ZeRO-3 sharding for 464000 non tensor-parallel parameters
Traceback (most recent call last):
  File "tp_test.py", line 16, in <module>
    outputs = model(**tokens)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/tensor_parallel/pretrained_model.py", line 78, in forward
    return self.wrapped_model(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/tensor_parallel/sharding.py", line 95, in forward
    return self.module(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/tensor_parallel/tensor_parallel.py", line 130, in forward
    return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1076, in forward
    transformer_outputs = self.transformer(
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 900, in forward
    outputs = block(
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 391, in forward
    attn_outputs = self.attn(
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/tensor_parallel/wrapper.py", line 71, in forward
    output = self.tp_wrapped_module(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 313, in forward
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/iris/u/em7/code/direct-preference-optimization/env/lib/python3.8/site-packages/transformers/pytorch_utils.py", line 106, in forward
    x = x.view(size_out)
RuntimeError: shape '[1, 7, 2496]' is invalid for input of size 16800

On version 1.2.4, I get the output:

The following patterns in state_rules were unused: ["re.compile('.*lm_head\\\\.weight$')", "re.compile('.*q_attn\\\\.weight$')", "re.compile('.*q_attn\\\\.bias$')"]
The following patterns in state_rules were unused: ["re.compile('.*lm_head\\\\.weight$')", "re.compile('.*q_attn\\\\.weight$')", "re.compile('.*q_attn\\\\.bias$')"]
Using ZeRO-3 sharding for 464000 non tensor-parallel parameters
tensor(4.7082, device='cuda:0', grad_fn=<NllLossBackward0>)

BlackSamorez commented 1 year ago

I can reproduce this on my machine. gpt2 works without issues, but gpt2-xl fails. I'll look into it.
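
One possible reading of the shape error, offered as a guess rather than a confirmed diagnosis: the numbers are consistent with the fused c_attn layer being sized in whole attention heads while its weight was split evenly by columns. The sketch below only checks that arithmetic, using the published Hugging Face config values (gpt2: n_embd=768, n_head=12; gpt2-xl: n_embd=1600, n_head=25). The two-GPU setup, the WORLD_SIZE constant, the shard_widths helper, and both splitting schemes are assumptions made for illustration; none of them is taken from the tensor_parallel source.

```python
# Published Hugging Face config values: (hidden size n_embd, attention heads n_head).
CONFIGS = {"gpt2": (768, 12), "gpt2-xl": (1600, 25)}

WORLD_SIZE = 2  # assumption: two GPUs; the issue does not state the device count


def shard_widths(n_embd, n_head, world_size):
    """Per-shard fused-QKV widths under two hypothetical splitting schemes."""
    head_dim = n_embd // n_head
    # Scheme A (hypothetical): split the 3 * n_embd output columns evenly.
    even = [3 * n_embd // world_size] * world_size
    # Scheme B (hypothetical): hand out whole heads, remainder to the first shards.
    base, extra = divmod(n_head, world_size)
    heads = [base + (1 if rank < extra else 0) for rank in range(world_size)]
    by_head = [3 * h * head_dim for h in heads]
    return even, by_head


for name, (n_embd, n_head) in CONFIGS.items():
    even, by_head = shard_widths(n_embd, n_head, WORLD_SIZE)
    print(name, "even split:", even, "head-aligned split:", by_head)

# Batch size 1 and sequence length 7 are read off the shape in the error message.
assert 1 * 7 * 2400 == 16800
```

Under these assumptions, gpt2-xl's 25 heads do not divide evenly across two devices, so the first shard would expect 13 * 64 * 3 = 2496 features while an even column split yields 2400, and 1 * 7 * 2400 = 16800 matches the input size in the error. gpt2's 12 heads split evenly (1152 features either way), which would fit the observation that only gpt2-xl fails.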