Closed Fafa-DL closed 1 year ago
@Fafa-DL Hi, mask2former is not supported in mmdeploy according to the doc. The error suggests that you used torch.einsum
operations in your model forward; you could remove them and try again. One example is here:
https://github.com/open-mmlab/mmdeploy/blob/bc1b6440cd4e9eb85d02a706bbe263aeec4e66f9/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py#L29-L45
@RunningLeon Hi, thanks. torch.einsum
exists in mask2former_head.py. Is the modification the same as in the example you provided?
https://github.com/open-mmlab/mmsegmentation/blob/916ed2b2e208bf88dfd5180c79baa9883abb8f00/mmseg/models/decode_heads/mask2former_head.py#L131-L163
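For reference, that particular pattern, torch.einsum('bqc, bqhw->bchw', ...), can be expressed as a single batched matmul over flattened spatial dimensions. A minimal sketch (the function name and tensor sizes below are illustrative, not from mask2former_head.py):

```python
import torch

def fuse_cls_and_mask(cls_score: torch.Tensor, mask_pred: torch.Tensor) -> torch.Tensor:
    """Equivalent of torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred).

    cls_score: (b, q, c), mask_pred: (b, q, h, w) -> returns (b, c, h, w).
    """
    b, q, h, w = mask_pred.shape
    # (b, c, q) @ (b, q, h*w) -> (b, c, h*w), summing over the query dim q
    seg_logits = torch.matmul(cls_score.transpose(1, 2), mask_pred.flatten(2))
    return seg_logits.view(b, -1, h, w)

# sanity check against the original einsum (illustrative sizes)
cls_score = torch.rand(2, 100, 19)
mask_pred = torch.rand(2, 100, 64, 64)
ref = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
assert torch.allclose(fuse_cls_and_mask(cls_score, mask_pred), ref, atol=1e-4)
```

This keeps the largest intermediate at the size of the output itself, which matters for the memory discussion below.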
@RunningLeon I made the following modification
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
mask_pred = mask_pred_results.sigmoid()
# seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
cls_score = cls_score.transpose(1, 2)
cls_score = cls_score.unsqueeze(3)
cls_score = cls_score.unsqueeze(4)
mask_pred = mask_pred.unsqueeze(1)
seg_logits = torch.matmul(cls_score, mask_pred)
but got RuntimeError: CUDA out of memory
while executing seg_logits = torch.matmul(cls_score, mask_pred)
Maybe you could try with 512x512 or device=cpu
In other words, will this method use more memory and run slower?
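A rough sketch of why the unsqueeze-based rewrite is heavier: a broadcasted matmul of (b, c, q, 1, 1) with (b, 1, q, h, w) implies a workload on the order of a (b, c, q, h, w) tensor, while flattening to a (b, c, q) @ (b, q, h*w) batched matmul never materializes anything larger than the inputs and the (b, c, h, w) output. The sizes below are assumed for illustration, not taken from the issue:

```python
# Element-count comparison of the two rewrites.
# Assumed sizes: batch 1, 100 queries, 19 classes, 1024x2048 output.
b, q, c, h, w = 1, 100, 19, 1024, 2048

# Broadcasted variant: expanding mask_pred across the class dim
# touches a (b, c, q, h, w)-sized workload.
broadcast_elems = b * c * q * h * w

# Flattened variant: peaks at the (b, c, h*w) output plus the
# flattened (b, q, h*w) input view.
flatten_elems = b * c * h * w + b * q * h * w

print(f'broadcast workload: {broadcast_elems:,} elements')
print(f'flattened workload: {flatten_elems:,} elements')
print(f'ratio: {broadcast_elems / flatten_elems:.1f}x')  # ~16x here
```

Under these assumed sizes the broadcasted form touches roughly 16 times as many elements, which is consistent with it running out of CUDA memory where the flattened form would not.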
Checklist
Describe the bug
Hi, I want to convert the weights of mask2former Swin-B (in22k) or mask2former R-50-D32 for deployment, and the error is as follows. It is worth mentioning that converting other models such as OCRNet works fine.
Reproduction
The executed command is as follows
and executing the following command in
mmsegmentation/
gives the correct result.
Environment