huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Switching Mask2Former Backbones #26403

Closed alen-smajic closed 1 year ago

alen-smajic commented 1 year ago

System Info

Who can help?

@amyeroberts

Information

Tasks

Reproduction

I would like to combine the DINOv2 backbone model with the Mask2Former model for semantic segmentation. The official documentation states that Mask2Former only works with a Swin Transformer backbone, but I stumbled upon issue #24244.

Multi-backbone support was implemented by @amyeroberts in PR #24532, along with some example code. Model instantiation works, but when I try to run any data through the model I get an error:

Script to reproduce:

from PIL import Image
import requests
from transformers import (
    Mask2FormerConfig, 
    Mask2FormerModel, 
    Mask2FormerImageProcessor,
    FocalNetConfig,
    Dinov2Config
)

backbone_config = FocalNetConfig(out_indices=(-2, -1))  # This is the official example from PR #24532
#backbone_config = Dinov2Config(out_indices=(-2, -1))  # This doesn't work either
mask2former_config = Mask2FormerConfig(backbone_config=backbone_config)
model = Mask2FormerModel(mask2former_config)

processor = Mask2FormerImageProcessor(size=(224, 224))

url = (
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
)
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, return_tensors="pt")

output = model(**inputs)

The error I get:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/asl2hi/DST/fm_semseg/debug_model.ipynb Cell 5 line 22
     19 image = Image.open(requests.get(url, stream=True).raw)
     20 inputs = processor(image, return_tensors="pt")
---> 22 output = model(**inputs)

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:2271, in Mask2FormerModel.forward(self, pixel_values, pixel_mask, output_hidden_states, output_attentions, return_dict)
   2268 if pixel_mask is None:
   2269     pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
-> 2271 pixel_level_module_output = self.pixel_level_module(
   2272     pixel_values=pixel_values, output_hidden_states=output_hidden_states
   2273 )
   2275 transformer_module_output = self.transformer_module(
   2276     multi_scale_features=pixel_level_module_output.decoder_hidden_states,
   2277     mask_features=pixel_level_module_output.decoder_last_hidden_state,
   2278     output_hidden_states=True,
   2279     output_attentions=output_attentions,
   2280 )
   2282 encoder_hidden_states = None

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:1396, in Mask2FormerPixelLevelModule.forward(self, pixel_values, output_hidden_states)
   1394 def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
   1395     backbone_features = self.encoder(pixel_values).feature_maps
-> 1396     decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states)
   1398     return Mask2FormerPixelLevelModuleOutput(
   1399         encoder_last_hidden_state=backbone_features[-1],
   1400         encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None,
   1401         decoder_last_hidden_state=decoder_output.mask_features,
   1402         decoder_hidden_states=decoder_output.multi_scale_features,
   1403     )

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:1320, in Mask2FormerPixelDecoder.forward(self, features, encoder_outputs, output_attentions, output_hidden_states, return_dict)
   1318 # Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
   1319 if encoder_outputs is None:
-> 1320     encoder_outputs = self.encoder(
   1321         inputs_embeds=input_embeds_flat,
   1322         attention_mask=masks_flat,
   1323         position_embeddings=level_pos_embed_flat,
   1324         spatial_shapes=spatial_shapes,
   1325         level_start_index=level_start_index,
   1326         valid_ratios=valid_ratios,
   1327         output_attentions=output_attentions,
   1328         output_hidden_states=output_hidden_states,
   1329         return_dict=return_dict,
   1330     )
   1332 last_hidden_state = encoder_outputs.last_hidden_state
   1333 batch_size = last_hidden_state.shape[0]

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:1175, in Mask2FormerPixelDecoderEncoderOnly.forward(self, inputs_embeds, attention_mask, position_embeddings, spatial_shapes, level_start_index, valid_ratios, output_attentions, output_hidden_states, return_dict)
   1172 if output_hidden_states:
   1173     all_hidden_states += (hidden_states.transpose(1, 0),)
-> 1175 layer_outputs = encoder_layer(
   1176     hidden_states,
   1177     attention_mask,
   1178     position_embeddings=position_embeddings,
   1179     reference_points=reference_points,
   1180     spatial_shapes=spatial_shapes,
   1181     level_start_index=level_start_index,
   1182     output_attentions=output_attentions,
   1183 )
   1185 hidden_states = layer_outputs[0]
   1187 if output_attentions:

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:1030, in Mask2FormerPixelDecoderEncoderLayer.forward(self, hidden_states, attention_mask, position_embeddings, reference_points, spatial_shapes, level_start_index, output_attentions)
   1027 residual = hidden_states
   1029 # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
-> 1030 hidden_states, attn_weights = self.self_attn(
   1031     hidden_states=hidden_states,
   1032     attention_mask=attention_mask,
   1033     encoder_hidden_states=hidden_states,
   1034     encoder_attention_mask=attention_mask,
   1035     position_embeddings=position_embeddings,
   1036     reference_points=reference_points,
   1037     spatial_shapes=spatial_shapes,
   1038     level_start_index=level_start_index,
   1039     output_attentions=output_attentions,
   1040 )
   1042 hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
   1043 hidden_states = residual + hidden_states

File ~/DST/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DST/venv/lib/python3.8/site-packages/transformers/models/mask2former/modeling_mask2former.py:964, in Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, position_embeddings, reference_points, spatial_shapes, level_start_index, output_attentions)
    960 if reference_points.shape[-1] == 2:
    961     offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    962     sampling_locations = (
    963         reference_points[:, :, None, :, None, :]
--> 964         + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
    965     )
    966 elif reference_points.shape[-1] == 4:
    967     sampling_locations = (
    968         reference_points[:, :, None, :, None, :2]
    969         + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
    970     )

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3

Expected behavior

If working properly, the code above should output the model predictions (of class Mask2FormerModelOutput), produced by running the input image through the new backbone and then forwarding the resulting feature maps to the Mask2Former model.
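
For reference, a successful forward pass would look roughly like this (a minimal sketch; the inspected fields are assumed from the Mask2FormerModelOutput documentation, and the shapes depend on the chosen backbone and image size):

output = model(**inputs)

# Assumed output fields, per the Mask2FormerModelOutput docs:
print(type(output).__name__)                                # Mask2FormerModelOutput
print(output.encoder_last_hidden_state.shape)               # last backbone feature map
print(output.pixel_decoder_last_hidden_state.shape)         # mask features from the pixel decoder
print(output.transformer_decoder_last_hidden_state.shape)   # per-query decoder embeddings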

LysandreJik commented 1 year ago

cc @rafaelpadilla would love it if you could take a look!

rafaelpadilla commented 1 year ago

Hi @alen-smajic ,

Thank you for reporting this issue. :)

The code you showed is not working because of out_indices=(-2, -1). Try replacing it with:

backbone_config = FocalNetConfig(out_indices=(1, 2, 3, 4))

The DINOv2 model is not supported as a backbone. The supported backbone configs are: BitConfig, ConvNextConfig, ConvNextV2Config, DinatConfig, FocalNetConfig, MaskFormerSwinConfig, NatConfig, ResNetConfig, SwinConfig, TimmBackboneConfig.
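
For completeness, the original repro with that one change would look like this (a sketch; the stage_names attribute is an assumption about the backbone config API, the rest follows the script above):

from transformers import FocalNetConfig, Mask2FormerConfig, Mask2FormerModel

# Point out_indices at explicit backbone stages instead of negative indices.
backbone_config = FocalNetConfig(out_indices=(1, 2, 3, 4))
print(backbone_config.stage_names)  # assumed attribute listing the stages the indices refer to

mask2former_config = Mask2FormerConfig(backbone_config=backbone_config)
model = Mask2FormerModel(mask2former_config)
# The processor and image code from the original script can then be reused unchanged.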

alen-smajic commented 1 year ago

Hi @rafaelpadilla ,

Thanks for the quick help. You are totally right, the out_indices attribute was not set correctly.

I have in fact managed to attach a DINOv2 backbone to the Mask2Former model, and it seems to work :)

import requests

from PIL import Image
import torch
from transformers import (
    AutoImageProcessor,
    Dinov2Config,
    Dinov2Model,
    Mask2FormerConfig,
    Mask2FormerForUniversalSegmentation
)

# Store Dinov2 weights locally 
dinov2_backbone_model = Dinov2Model.from_pretrained("facebook/dinov2-base", out_indices=[6, 8, 10, 12])
torch.save(dinov2_backbone_model.state_dict(), "dinov2-base.pth")

# Create Mask2Former config with Dinov2 backbone
image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-cityscapes-semantic")
model_config = Mask2FormerConfig.from_pretrained("facebook/mask2former-swin-tiny-cityscapes-semantic")
model_config.backbone_config = Dinov2Config.from_pretrained("facebook/dinov2-base", out_indices=(6, 8, 10, 12))

# Instantiate Mask2Former model with Dinov2 backbone (random weights)
model = Mask2FormerForUniversalSegmentation(model_config)

# Load Dinov2 weights into Mask2Former backbone
dinov2_backbone = model.model.pixel_level_module.encoder
dinov2_backbone.load_state_dict(torch.load("dinov2-base.pth"))

url = (
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
)
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
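
A short note on inspecting the result of the snippet above (a sketch; post_process_semantic_segmentation returns one segmentation map per input image):

semantic_map = results[0]     # (height, width) tensor of predicted class ids
print(semantic_map.shape)
print(semantic_map.unique())  # class ids are essentially random here, since only the backbone weights were loaded
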
rafaelpadilla commented 1 year ago

Hi @alen-smajic ,

Glad to see the problem was solved :)

I will close this issue for now. Feel free to re-open it in case you encounter any related concerns in the future.

matteot11 commented 1 year ago

Hi @alen-smajic, thanks for the snippet, I managed to use DINOv2 as a backbone for Mask2Former. Did you try to fine-tune it on your own data? I am experiencing very low performance. Could the reason be that the authors of DINOv2 used ViT-Adapter? Any additional suggestions would be very appreciated :)

morrisalp commented 8 months ago

> Hi @alen-smajic, thanks for the snippet, I managed to use DINOv2 as a backbone for Mask2Former. Did you try to fine-tune it on your own data? I am experiencing very low performance. Could the reason be that the authors of DINOv2 used ViT-Adapter? Any additional suggestions would be very appreciated :)

From their notebook here I can see that they use ViT-Adapter (see the model summary, which contains ViTAdapter).

antopost commented 5 months ago

Interested to know if anyone has implemented Mask2Former with DINOv2 and ViT-Adapter using Hugging Face modules rather than mmseg.