huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

torch_ones in fx.optimization.FuseBiasInLinear creates tensor/proxy on CPU #628

IlyasMoutawwakil closed this issue 1 year ago

IlyasMoutawwakil commented 1 year ago

System Info

Colab Environment:

- `optimum` version: 1.5.2
- `transformers` version: 4.25.1
- Platform: Linux-5.10.133+-x86_64-with-glibc2.27
- Python version: 3.8.16
- Huggingface_hub version: 0.11.1
- PyTorch version (GPU?): 1.13.0+cu116 (True)
- Using GPU in script?: True

Who can help?

@regisss

In this line of `FuseBiasInLinear`, the device could be specified so that the ones tensor is created on the same device as the linear input:

```python
return torch.cat([linear_input, torch_ones(shape, device=linear_input.device)], dim=-1)
```

I tested this change and it fixes the error.
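
A minimal eager-mode sketch of the intended behavior (`append_ones` below is a hypothetical stand-in for the expression the transformation traces into the graph, not optimum's actual code): the ones tensor inherits the device and dtype of the linear input, so the concatenation never mixes CPU and CUDA tensors.

```python
import torch

def append_ones(linear_input: torch.Tensor) -> torch.Tensor:
    # Append a column of ones so the bias can be folded into the weight matrix
    shape = linear_input.shape[:-1] + (1,)
    # Creating the ones on linear_input.device keeps both cat operands together
    ones = torch.ones(shape, device=linear_input.device, dtype=linear_input.dtype)
    return torch.cat([linear_input, ones], dim=-1)
```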


Reproduction

To reproduce this error:

```python
import torch
from transformers import BertModel
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import FuseBiasInLinear

# original model
original_bert = BertModel.from_pretrained("bert-base-uncased")
original_bert.eval()

# traced model
traced_bert = symbolic_trace(
    original_bert,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
)

# put models on cuda
device = "cuda"
original_bert.to(device)
traced_bert.to(device)

# input configuration
bert_inputs = dict()
batch_size, seq_length = 4, 512

# create inputs
bert_inputs["input_ids"] = torch.zeros(batch_size, seq_length, dtype=torch.int, device=device).random_(original_bert.config.vocab_size)
bert_inputs["token_type_ids"] = torch.zeros(batch_size, seq_length, dtype=torch.int, device=device)
bert_inputs["attention_mask"] = torch.ones(batch_size, seq_length, dtype=torch.int, device=device)

# transform the graph
transformation = FuseBiasInLinear()
transformed_bert = transformation(traced_bert)
transformed_outputs = transformed_bert(**bert_inputs)
```

Traceback:

```python
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 267, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File ".39", line 34, in forward
    cat = torch.cat([embeddings_dropout, ones], dim = -1);  ones = None
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_cat)

Call using an FX-traced Module, line 34 of the traced Module's generated forward function:
    ones = torch.ones(add_87);  add_87 = None
    cat = torch.cat([embeddings_dropout, ones], dim = -1);  ones = None
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    encoder_layer_0_attention_self_query = getattr(self.encoder.layer, "0").attention.self.query(cat);  cat = None
    getattr_2 = embeddings_dropout.shape

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
      4 transformed_bert = composition(traced_bert)
      5 transformed_bert.to(device)
----> 6 transformed_outputs = transformed_bert(**bert_inputs)
      7
      8 # verify outputs

1 frames
/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py in __call__(self, obj, *args, **kwargs)
    273         print(_WrappedCall._generate_error_message(topmost_framesummary),
    274               file=sys.stderr)
--> 275         raise e.with_traceback(None)
    276     else:
    277         raise e

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_cat)
```
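
For reference, the root cause can be reproduced in isolation, without transformers or FX: `torch.ones` defaults to the CPU when no device is passed, so concatenating its result with a CUDA tensor raises the same RuntimeError. A minimal sketch, assuming a CUDA device is available:

```python
import torch

x = torch.randn(4, 512, 768, device="cuda")  # stands in for embeddings_dropout
ones = torch.ones(4, 512, 1)                 # no device argument: lands on CPU
torch.cat([x, ones], dim=-1)                 # RuntimeError: cuda:0 and cpu
```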

Expected behavior

The transformed model should run without error on GPU, with the ones tensor created on the same device as the rest of the graph.
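
Once fixed, a quick sanity check could compare the transformed model against the original (a sketch, assuming the traced module returns a dict keyed like `BertModel`'s output):

```python
import torch

with torch.no_grad():
    original_outputs = original_bert(**bert_inputs)

# assumes transformed_outputs is a dict with the same keys as BertModel's output
assert torch.allclose(
    transformed_outputs["last_hidden_state"],
    original_outputs.last_hidden_state,
    atol=1e-4,
), "fused model diverged from the original"
```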

regisss commented 1 year ago

Hi Ilyas, good catch! I'm able to reproduce this error. Would you like to open a PR to fix this?

IlyasMoutawwakil commented 1 year ago

Yes, of course, I would love to!