huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RuntimeError when performing Mask2Former traced model inference on a different device #33443

Closed matteot11 closed 2 weeks ago

matteot11 commented 1 month ago

System Info

Who can help?

@amyeroberts

Information

Tasks

Reproduction

import torch
from transformers import Mask2FormerForUniversalSegmentation

device = torch.device("cuda:0")
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-tiny-coco-instance", torchscript=True
)
model.eval().to(device)
dummy_input = torch.randn((1, 3, 640, 640)).to(device)
traced_model = torch.jit.trace(model, dummy_input)

# on cuda:0 everything works as expected
model(dummy_input) # <--OK!
traced_model(dummy_input) # <--OK!

device = torch.device("cuda:1")
model.to(device)(dummy_input.to(device))  # <-- OK!
traced_model.to(device)(dummy_input.to(device))  # <-- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

The code above generates the following error:

Traceback (most recent call last):
  File "test_mask2former_export.py", line 21, in <module>
    traced_model.to(device)(torch.rand((2, 3, 640, 640)).to(device))
  File "~/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(524): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(593): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(720): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(789): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(869): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/swin/modeling_swin.py(1370): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/mask2former/modeling_mask2former.py(1389): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/mask2former/modeling_mask2former.py(2269): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/transformers/models/mask2former/modeling_mask2former.py(2498): forward
~/python3.12/site-packages/torch/nn/modules/module.py(1543): _slow_forward
~/python3.12/site-packages/torch/nn/modules/module.py(1562): _call_impl
~/python3.12/site-packages/torch/nn/modules/module.py(1553): _wrapped_call_impl
~/python3.12/site-packages/torch/jit/_trace.py(1275): trace_module
~/python3.12/site-packages/torch/jit/_trace.py(695): _trace_impl
~/python3.12/site-packages/torch/jit/_trace.py(1000): trace
test_mask2former_export.py(10): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

The same happens whenever the traced Mask2Former inference is run on a different device than the one it was traced with, e.g. if device = torch.device("cpu") in line 4 and device = torch.device("cuda:0") in line 16.

Expected behavior

I would expect inference to work correctly even when performed on a different device. For example, a GPU may not be available on the system where jit.trace() is called but be available on the system where inference is run. Moreover, in a multi-GPU system, it may be necessary to run the traced model on a device other than the one used for tracing (e.g. "cuda:1" instead of "cuda:0"). As of now, the traced model can run inference only on the exact device used for tracing.

Any help on this would be greatly appreciated.

SangbumChoi commented 1 month ago

This happens because some layers were compiled on cuda:0 while you are running on cuda:1, or vice versa. Calling .to(device) does not change the device already baked into those layers at compile time :)

This kind of error can be resolved by a PR like https://github.com/huggingface/transformers/pull/33412, which avoids pinning the GPU device.

matteot11 commented 1 month ago

Hi, thanks for your reply. Do you mean that Mask2Former has some layer on which the .to(device) has no effect? If so, why does the error happen only for the compiled model, and not for the base model? Could you please point me to the changes in the mentioned PR that should fix the issue? It seems that it is about RT_DETR; do you mean that a similar fix should be applied to Mask2Former? Thanks again.

SangbumChoi commented 1 month ago

@matteot11 Yeah, let me explain one by one.

Do you mean that Mask2Former has some layer on which the .to(device) has no effect?

In the compiled version the device is already baked in at the layer level, so .to(device) has no effect. This is also why it happens only in the compiled version.

Could you please point me to the changes in the mentioned PR that should fix the issue? It seems that it is about RT_DETR; do you mean that a similar fix should be applied to Mask2Former?

https://paulbridger.com/posts/mastering-torchscript/#device-pinning My recommendation is to re-compile for cuda:1:

# Move the original (non-compiled) model to GPU 1
model = model.to('cuda:1')

# Recompile for GPU 1
compiled_model_gpu1 = torch.compile(model)

# Run on GPU 1
with torch.cuda.device(1):
    output = compiled_model_gpu1(input_data)

matteot11 commented 1 month ago

So if I understood correctly, tensors created at inference time will have their device fixed after tracing. However, for other models like SegformerForSemanticSegmentation, the .to(device) works as expected on the traced model, maybe because there are no tensors created on the fly.
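The device pinning described above can be reproduced on a small scale. The following is a minimal sketch, not code from Mask2Former: `MakesTensorInForward` is a hypothetical toy module that creates a tensor inside forward(), and it is traced on CPU so that no GPU is needed. Printing the traced graph should show the device that was read from the example input next to the aten::zeros call.

```python
import torch
from torch import nn


class MakesTensorInForward(nn.Module):
    """Hypothetical toy module that creates a tensor on the fly."""

    def forward(self, x):
        # The device is read from the example input while tracing.
        mask = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        return x + mask


example = torch.randn(2, 3)
traced = torch.jit.trace(MakesTensorInForward(), example)

# Text dump of the recorded graph; look for the device near aten::zeros.
graph_text = str(traced.graph)
print(graph_text)
```

A module without such on-the-fly tensors has nothing of this kind recorded in its graph, which would be consistent with .to(device) working on the traced Segformer model.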

Looking here it seems that a way to solve the issue is replacing the tensors created at runtime with Parameters. In the PR you mentioned before I noticed that tensors continue to be created at runtime, with the only addition of the device argument. In my understanding this would not solve the problem, right?

So, at the moment, the only way to run traced Mask2Former inference on a cluster of (let's say) 4 GPUs, is to export 4 different models, one for each device?

SangbumChoi commented 1 month ago

So if I understood correctly, tensors created at inference time will have their device fixed after tracing. However, for other models like SegformerForSemanticSegmentation, the .to(device) works as expected on the traced model, maybe because there are no tensors created on the fly.

Yeah, maybe because there are no tensors created on the fly. This is the main point.

Looking here it seems that a way to solve the issue is replacing the tensors created at runtime with Parameters. In https://github.com/huggingface/transformers/pull/33412 I noticed that tensors continue to be created at runtime, with the only addition of the device argument. In my understanding this would not solve the problem, right?

You are right, this is not the solution to the problem. I attached the link only to explain that, if you want to solve the main problem rather than work around it, you might need to rework some modules.

So, at the moment, the only way to run traced Mask2Former inference on a cluster of (let's say) 4 GPUs, is to export 4 different models, one for each device?

Maybe yes? I haven't tried :)
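For reference, the per-device export workaround discussed here could look like the sketch below. It is untested on Mask2Former: `trace_per_device` is a hypothetical helper, and a small nn.Linear stands in for the real model so the example runs on CPU. It uses torch.jit.trace, as in the original report, rather than torch.compile.

```python
import copy

import torch
from torch import nn


def trace_per_device(model, example_input, devices):
    """Trace an independent copy of `model` on each device (hypothetical helper)."""
    traced_by_device = {}
    for name in devices:
        device = torch.device(name)
        # Copy first so the caller's model stays on its current device.
        model_copy = copy.deepcopy(model).eval().to(device)
        traced_by_device[name] = torch.jit.trace(model_copy, example_input.to(device))
    return traced_by_device


# Toy stand-in for the real model. On a multi-GPU machine the list could be
# ["cuda:0", "cuda:1", ...] instead of ["cpu"].
toy_model = nn.Linear(4, 2)
traced_models = trace_per_device(toy_model, torch.randn(1, 4), ["cpu"])
output = traced_models["cpu"](torch.randn(1, 4))
```

Each traced copy can then be saved separately with torch.jit.save and loaded on the matching device, at the cost of tracing and storing the model once per device.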

matteot11 commented 1 month ago

Ok thanks for your help. I would keep the issue open if possible, in case someone else has any other hint on this.

qubvel commented 1 month ago

Hey @matteot11 and @SangbumChoi! Thanks for raising the issue and the discussion! If you manage to find a fix rather than a workaround, please let me know! Contributions are also very welcome 🤗

matteot11 commented 1 month ago

As suspected, after investigating the TorchScript inference traceback, it seems that the error is due to some tensors initialized at forward time. These are all the tensors causing device pinning:

https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/swin/modeling_swin.py#L651
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L861
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L870
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L1102
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L1103
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L1297
https://github.com/huggingface/transformers/blob/1027a532c5435b6116feba299f2cfad66b93c2c4/src/transformers/models/mask2former/modeling_mask2former.py#L1303

To solve the issue, I tried initializing those tensors in the corresponding nn.Module __init__(), using self.register_buffer, and then using those buffers at forward time. However, the shape of some of those tensors depends on the shape of the forward's input: the only way I found to make it work was to register the buffer with a minimal shape, and then repeat it several times based on the input shape. For instance, the following:

class SwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        ...

    def get_attn_mask(self, height, width, dtype, device):
        ...
        img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
        ...

becomes:

class SwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        ...
        self.register_buffer("img_mask", torch.zeros((1, 1, 1, 1), dtype=torch.float32))

    def get_attn_mask(self, height, width):
        ...
        img_mask = self.img_mask.repeat(1, height, width, 1)

Maybe it's not the best solution, but in this way I am able to move traced Mask2Former to any device and perform inference. Any suggestions or alternatives are welcome!
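A self-contained version of the same idea, for anyone who wants to try it outside of Swin. This is only a sketch under assumptions: `MaskFromBuffer` and the buffer name `mask_seed` are hypothetical, and the example is traced and run on CPU, so it shows the mechanics rather than proving the cross-device behaviour.

```python
import torch
from torch import nn


class MaskFromBuffer(nn.Module):
    """Hypothetical module that builds its mask from a registered buffer."""

    def __init__(self):
        super().__init__()
        # Minimal-shape seed; buffers follow the module on .to(device).
        self.register_buffer("mask_seed", torch.zeros((1, 1, 1, 1)))

    def forward(self, x):
        _, _, height, width = x.shape
        # Grow the seed to the input's spatial size instead of calling
        # torch.zeros(..., device=...) inside forward().
        return self.mask_seed.repeat(1, height, width, 1)


module = MaskFromBuffer().eval()
traced = torch.jit.trace(module, torch.randn(1, 3, 8, 8))
out = traced(torch.randn(1, 3, 8, 8))
print(out.shape)
```

Since the mask is derived from a buffer, it is part of the module's state and is moved together with the parameters when .to(device) is called.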
