open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0
5.83k stars 1.63k forks source link

[Bug] ms_deform_attn_forward_cuda does not support BF16 #2878

Open AkideLiu opened 1 year ago

AkideLiu commented 1 year ago

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.10.12 (main, Jul 5 2023, 18:54:27) [GCC 11.2.0]'), ('CUDA available', True), ('numpy_random_seed', 2147483648), ('GPU 0', 'NVIDIA GeForce RTX 4090'), ('CUDA_HOME', '/usr/local/cuda'), ('NVCC', 'Cuda compilation tools, release 12.2, V12.2.91'), ('GCC', 'gcc (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0'), ('PyTorch', '2.0.1'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.8\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37\n - CuDNN 8.7\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing 
-Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.15.2'), ('OpenCV', '4.8.0'), ('MMEngine', '0.8.2'), ('MMCV', '2.0.1'), ('MMCV Compiler', 'GCC 9.3'), ('MMCV CUDA Compiler', '11.8')])

Reproduces the problem - code sample

Add the following section to the Mask2Former config:

[mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py](mmseg::mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py)

# bf16 mixed-precision training: AmpOptimWrapper with dtype='bfloat16'.
# With this setting the deformable-attention CUDA op raises the
# "not implemented for 'BFloat16'" error reported below.
optim_wrapper = {
    '_delete_': True,
    'type': 'AmpOptimWrapper',
    'dtype': 'bfloat16',
    'optimizer': optimizer,
    'clip_grad': {'max_norm': 0.01, 'norm_type': 2},
    'paramwise_cfg': {
        'custom_keys': custom_keys,
        'norm_decay_mult': 0.0,
    },
}

Reproduces the problem - command or script

tools/train.py configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py --amp

Reproduces the problem - error message

RuntimeError: "ms_deform_attn_forward_cuda" not implemented for 'BFloat16'

Additional information

Should we add a conditional check so that, when BFloat16 is not implemented by the CUDA op, the code falls back to the native PyTorch implementation?

AkideLiu commented 1 year ago

I have built a temporary fix, but it is ugly and complicated:

import copy

import math
import warnings
from typing import Optional, no_type_check

import mmengine
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union

from mmdet.models.task_modules import MlvlPointGenerator
from mmdet.utils import ConfigType, OptMultiConfig
from mmcv.cnn import ConvModule
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch
from mmdet.models import MSDeformAttnPixelDecoder, Mask2FormerTransformerEncoder, DeformableDetrTransformerEncoderLayer, \
    SinePositionalEncoding
from mmengine.model import BaseModule, constant_init, xavier_init
from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.registry import MODELS
from mmdet.registry import MODELS as MMDET_MODELS
from torch.nn import ModuleList
from mmcv.cnn import Conv2d, ConvModule

@MODELS.register_module()
class TorchMultiScaleDeformableAttention(MultiScaleDeformableAttention):
    """MultiScaleDeformableAttention that can bypass the compiled kernel.

    The compiled ``ms_deform_attn`` op raises
    ``"ms_deform_attn_forward_cuda" not implemented for 'BFloat16'``.
    This subclass therefore computes the attention with the pure-PyTorch
    ``multi_scale_deformable_attn_pytorch`` in two cases:

    * bfloat16 operands are detected (automatically), or
    * ``force_not_using_cuda_ops`` is truthy (always).

    Args:
        force_not_using_cuda_ops (bool, optional): Always use the PyTorch
            implementation. If ``None`` (default), the module-level
            ``force_not_using_cuda_ops`` global is used. That global is set by
            ``TorchMSDeformAttnPixelDecoder``. If the global was never set,
            the flag is ``False``.
        **kwargs: Forwarded unchanged to ``MultiScaleDeformableAttention``.
    """

    def __init__(self,
                 *,
                 force_not_using_cuda_ops: Optional[bool] = None,
                 **kwargs):
        super().__init__(**kwargs)
        if force_not_using_cuda_ops is None:
            # The pixel decoder publishes its config flag through a module
            # global. Read it defensively so that building this attention on
            # its own no longer raises NameError when the global is absent.
            force_not_using_cuda_ops = globals().get(
                'force_not_using_cuda_ops', False)
        self.force_not_using_cuda_ops = bool(force_not_using_cuda_ops)

    @no_type_check
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query: torch.Tensor,
                key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                identity: Optional[torch.Tensor] = None,
                query_pos: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None,
                reference_points: Optional[torch.Tensor] = None,
                spatial_shapes: Optional[torch.Tensor] = None,
                level_start_index: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (torch.Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (torch.Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`. Unused by this implementation.
            value (torch.Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`. Defaults to `query`.
            identity (torch.Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (torch.Tensor): The positional encoding for `query`.
                Default: None.
            key_padding_mask (torch.Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            reference_points (torch.Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2).
                All elements are in the range [0, 1], with top-left (0, 0)
                and bottom-right (1, 1), including the padding area.
                Alternatively the shape is (N, Length_{query}, num_levels, 4),
                where the two additional dimensions are (w, h) and form
                reference boxes.
            spatial_shapes (torch.Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (torch.Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            torch.Tensor: forwarded results with shape
            [num_query, bs, embed_dims].

        Raises:
            ValueError: If the last dim of `reference_points` is not 2 or 4.
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)

        if reference_points.shape[-1] == 2:
            # spatial_shapes stores (h, w); offsets are normalized by (w, h).
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets
                / offset_normalizer[None, None, None, :, None, :])
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.num_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5)
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')

        on_accelerator = ((IS_CUDA_AVAILABLE and value.is_cuda)
                          or (IS_MLU_AVAILABLE and value.is_mlu))
        # The compiled op has no BFloat16 implementation (it raises
        # "not implemented for 'BFloat16'"), so any bf16 operand is routed
        # to the PyTorch fallback instead of crashing.
        has_bf16_operand = torch.bfloat16 in (
            value.dtype, sampling_locations.dtype, attention_weights.dtype)
        if (on_accelerator and not self.force_not_using_cuda_ops
                and not has_bf16_operand):
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # back to (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)

        return self.dropout(output) + identity

class TorchDeformableDetrTransformerEncoderLayer(DeformableDetrTransformerEncoderLayer):
    """Deformable-DETR encoder layer using TorchMultiScaleDeformableAttention.

    Construction is inherited unchanged; only ``_init_layers`` is overridden
    so that the self-attention sub-module is swapped after the parent has
    built its modules.
    """

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn, and norms, then replace self_attn."""
        super()._init_layers()
        attn_cfg = self.self_attn_cfg
        self.self_attn = TorchMultiScaleDeformableAttention(**attn_cfg)

class TorchMask2FormerTransformerEncoder(Mask2FormerTransformerEncoder):
    """Mask2Former pixel-decoder encoder built from Torch encoder layers.

    Every layer is a ``TorchDeformableDetrTransformerEncoderLayer``.

    No ``__init__`` override is needed. mmdet's ``DetrTransformerEncoder``
    constructor already calls ``self._init_layers()``, which dispatches to the
    override below. The previous extra explicit call rebuilt every layer a
    second time and discarded the first set.
    """

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            TorchDeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

@MMDET_MODELS.register_module()
class TorchMSDeformAttnPixelDecoder(MSDeformAttnPixelDecoder):
    """MSDeformAttnPixelDecoder that uses TorchMask2FormerTransformerEncoder.

    An optional ``force_not_using_cuda_ops`` key under
    ``encoder.layer_cfg.self_attn_cfg`` is removed before the parent
    constructor runs. Presumably the stock attention module would reject the
    unknown key. The value is then published as the module-level
    ``force_not_using_cuda_ops`` global, which
    ``TorchMultiScaleDeformableAttention`` reads when constructed.
    Because it is a process-wide global, the most recently built decoder
    determines the value.

    Raises:
        ValueError: If no ``encoder`` config is passed as a keyword argument.
    """

    def __init__(self, *args, **kwargs):
        global force_not_using_cuda_ops

        encoder_cfg = kwargs.get('encoder', None)
        if encoder_cfg is None:
            raise ValueError(
                'TorchMSDeformAttnPixelDecoder requires an `encoder` config '
                'passed as a keyword argument.')

        # Work on a private copy. mmengine only shallow-copies the model
        # config, so popping from the nested dict in place would strip the
        # flag from the caller's config and a second build would lose it.
        encoder_cfg = copy.deepcopy(encoder_cfg)
        self_attn_cfg = encoder_cfg.get('layer_cfg', {}).get('self_attn_cfg', {})
        force_not_using_cuda_ops = self_attn_cfg.pop(
            'force_not_using_cuda_ops', None)
        # kwargs is this call's own dict, so rebinding does not affect callers.
        kwargs['encoder'] = encoder_cfg

        super().__init__(*args, **kwargs)
        # Replace the encoder built by the parent with the Torch variant.
        self.encoder = TorchMask2FormerTransformerEncoder(**encoder_cfg)
        self.init_weights()