huggingface / transformers


[mask2former] torch.export error for Mask2Former #34390

Open · philkuz opened this issue 1 week ago

philkuz commented 1 week ago

System Info

Who can help?

@amyeroberts, @qubvel, @ylacombe

Information

Tasks

Reproduction

import torch
from transformers import Mask2FormerForUniversalSegmentation

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic", torchscript=True
)

# Export with a single static pixel_values input of shape (1, 3, 800, 1280);
# torch.export.export returns an ExportedProgram on success.
exported_program = torch.export.export(model, args=(torch.randn(1, 3, 800, 1280),))

which causes

UserError: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File "/home/philkuz/.pyenv/versions/3.11.9/envs/gml311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 2499, in forward
    outputs = self.model(
  File "/home/philkuz/.pyenv/versions/3.11.9/envs/gml311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 2270, in forward
    pixel_level_module_output = self.pixel_level_module(
  File "/home/philkuz/.pyenv/versions/3.11.9/envs/gml311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 1395, in forward
    decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states)
  File "/home/philkuz/.pyenv/versions/3.11.9/envs/gml311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 1319, in forward
    encoder_outputs = self.encoder(
  File "/home/philkuz/.pyenv/versions/3.11.9/envs/gml311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 1165, in forward
    reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
  File "/home/philkuz/dev/transformers/src/transformers/models/mask2former/modeling_mask2former.py", line 1106, in get_reference_points
    torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
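
For context, the same class of failure reproduces in isolation: reading an integer out of a tensor's data and passing it where a Python int is required gives torch.export an unbacked symbol (u0) that it cannot specialize. A minimal sketch (illustrative only, not the transformers code):

import torch

class Repro(torch.nn.Module):
    def forward(self, spatial_shapes):
        # `height` is read out of the tensor's data, so torch.export models it
        # as an unbacked symbolic int (u0) rather than a static size.
        height = spatial_shapes[0, 0]
        # torch.linspace needs a concrete int for `steps`, which triggers the
        # "could not extract specialized integer" error quoted above.
        return torch.linspace(0.5, height - 0.5, height)

# Runs fine in eager mode, fails under torch.export with the same UserError.
torch.export.export(Repro(), args=(torch.tensor([[32, 64]]),))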

Expected behavior

torch.export.export succeeds for this model and returns an ExportedProgram.
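
Concretely, something like the following should run (hypothetical once the shape handling is fixed):

exported = torch.export.export(model, args=(torch.randn(1, 3, 800, 1280),))
print(exported.graph_module.graph)              # inspectable FX graph
torch.export.save(exported, "mask2former.pt2")  # serializable artifact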

philkuz commented 1 week ago

The crux of the issue is that static shape information gets erased in a few places, for example: https://github.com/huggingface/transformers/blob/3d99f1746e0d667cbec9e69b4ec11289c4752630/src/transformers/models/mask2former/modeling_mask2former.py#L1307
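
That line collapses the per-level (height, width) pairs into a runtime tensor, so get_reference_points later reads the sizes back out of tensor data. Roughly (an illustrative sketch with made-up variable names, not the exact code):

# Static Python ints taken from the feature-map shapes...
spatial_shapes = [(feat.shape[2], feat.shape[3]) for feat in feature_maps]
# ...are collapsed into a runtime tensor, erasing the static values:
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device)

# Downstream, each `height`/`width` is now a 0-dim tensor whose value
# torch.export treats as data-dependent:
for height, width in spatial_shapes:
    torch.linspace(0.5, height - 0.5, height)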

There is also a downstream issue similar to the one raised in https://github.com/huggingface/transformers/issues/34022 and fixed in https://github.com/huggingface/transformers/pull/34023

I've narrowed this down to a minimal set of changes that enable export for this model in my local fork and will put out a PR shortly; the general shape of the fix is sketched below.
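
As a hedged sketch of the fix pattern from #34023 (simplified, with illustrative names, and not the actual patch): keep the per-level sizes as plain Python ints alongside the tensor, and use the ints wherever an integer is required.

import torch

def get_reference_points_sketch(feature_maps, device, dtype=torch.float32):
    # Keep static (height, width) as Python ints; build a tensor only where
    # a tensor is genuinely needed.
    spatial_shapes_list = [(int(f.shape[2]), int(f.shape[3])) for f in feature_maps]

    reference_points = []
    for height, width in spatial_shapes_list:  # Python ints, not 0-dim tensors
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, height - 0.5, height, dtype=dtype, device=device),
            torch.linspace(0.5, width - 0.5, width, dtype=dtype, device=device),
            indexing="ij",
        )
        reference_points.append(torch.stack([ref_x, ref_y], dim=-1))
    return reference_points

With the sizes kept as ints, torch.linspace receives concrete values and torch.export no longer has to guard on data-dependent expressions.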