EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
https://www.eleuther.ai/
Apache License 2.0

The xformers result does not match the normal attention result #998

Closed: guozhiyao closed this issue 1 year ago

guozhiyao commented 1 year ago

I am trying to modify the Hugging Face GPT-NeoX model to use xformers attention, but its output does not match the original attention implementation. Could you help me with it?

Here is my code, which is adapted from the GPT-NeoX attention implementation in Hugging Face transformers.

import xformers.ops as xops
import torch
from torch import nn
import pdb

dtype = torch.bfloat16

def _xformers_attn(query, key, value, **kwargs):
    query = query.transpose(1, 2).to(value.dtype)
    key = key.transpose(1, 2).to(value.dtype)
    value = value.transpose(1, 2)

    output = xops.memory_efficient_attention(
        query, key, value, op=xops.MemoryEfficientAttentionFlashAttentionOp,
        attn_bias=xops.LowerTriangularMask(),
        p=0
    )
    matmul_result = output.transpose(1, 2)

    return matmul_result.to(query.dtype), None

def _init_bias(max_positions, device=None):
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
            1, 1, max_positions, max_positions
        )
    if device is not None:
        bias = bias.to(device)
    return bias

hidden_size = 5120
num_attention_heads = 40
head_size = hidden_size // num_attention_heads
bias = _init_bias(2048, "cuda")
norm_factor = torch.sqrt(torch.tensor(head_size, dtype=dtype, device="cuda"))
attention_dropout = nn.Dropout(0)
def _attn(query, key, value, attention_mask=None, head_mask=None):
    global bias
    # compute causal mask from causal mask buffer
    batch_size, num_attention_heads, query_length, attn_head_size = query.size()
    key_length = key.size(-2)

    # dynamically increase the causal mask with the key length, if needed.
    if key_length > bias.shape[-1]:
        bias = _init_bias(key_length, device=key.device)
    causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]

    query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
    key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
    attn_scores = torch.zeros(
        batch_size * num_attention_heads,
        query_length,
        key_length,
        dtype=query.dtype,
        device=key.device,
    )
    attn_scores = torch.baddbmm(
        attn_scores,
        query,
        key.transpose(1, 2),
        beta=1.0,
        alpha=(torch.tensor(1.0, dtype=norm_factor.dtype, device=norm_factor.device) / norm_factor),
    )
    attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

    mask_value = torch.finfo(attn_scores.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
    attn_scores = torch.where(causal_mask, attn_scores, mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attn_scores = attn_scores + attention_mask

    attn_weights = nn.functional.softmax(attn_scores, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_weights = attention_dropout(attn_weights)

    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights

bs = 1
seq_len = 128
attn_head_size = head_size
q = torch.randn((bs, num_attention_heads, seq_len, attn_head_size)).cuda().to(dtype)
k = torch.randn((bs, num_attention_heads, seq_len, attn_head_size)).cuda().to(dtype)
v = torch.randn((bs, num_attention_heads, seq_len, attn_head_size)).cuda().to(dtype)
out1, _ = _attn(q,k,v)
out2, _ = _xformers_attn(q,k,v)
print(torch.allclose(out1, out2))
pdb.set_trace()
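One likely contributor to the mismatch, though nothing in this thread confirms it: both outputs are bfloat16, and `torch.allclose` with its default tolerances (rtol=1e-5, atol=1e-8) is far stricter than bfloat16 precision. bfloat16 stores only 7 explicit mantissa bits, so neighbouring values near 1.0 are 2^-7 = 0.0078125 apart, and two kernels that round intermediate results differently will rarely agree to five significant digits. The script also builds `norm_factor` in bfloat16, which rounds the softmax scale. xformers is assumed here to apply its default scale of 1/sqrt(head_dim) in higher precision. The resulting relative difference of roughly 1e-4 in the scale is small next to bfloat16's own rounding, but it already exceeds rtol=1e-5. The stdlib-only sketch that follows emulates bfloat16 rounding to illustrate both effects. `to_bf16` and the scalar `allclose` are hypothetical helpers written for this illustration, not PyTorch or xformers functions. `to_bf16` only approximates PyTorch's conversion: it rounds to float32 first and does not handle NaN or inf.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the upper 16 bits of an IEEE-754 float32: 1 sign bit,
    8 exponent bits and 7 explicit mantissa bits. NaN/inf are not handled.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 low bits that get dropped.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def allclose(a: float, b: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Hypothetical scalar version of the test torch.allclose applies elementwise."""
    return abs(a - b) <= atol + rtol * abs(b)


head_size = 128  # hidden_size 5120 / 40 heads, as in the script above

# Scale used by _attn: norm_factor is created in bfloat16, and
# 1 / norm_factor is rounded to bfloat16 again.
norm_factor_bf16 = to_bf16(math.sqrt(head_size))
alpha_bf16 = to_bf16(1.0 / norm_factor_bf16)

# Assumption: xformers uses 1/sqrt(head_dim) kept in higher precision.
alpha_exact = 1.0 / math.sqrt(head_size)

print(norm_factor_bf16)                             # 11.3125
print(alpha_bf16)                                   # 0.08837890625
print(abs(alpha_bf16 - alpha_exact) / alpha_exact)  # about 1.1e-4

# Spacing of bfloat16 values just above 1.0 is 2**-7.
one_ulp_up = to_bf16(1.0 + 2**-7)
print(one_ulp_up)                                   # 1.0078125
print(allclose(1.0, one_ulp_up))                    # False with default tolerances
print(allclose(1.0, one_ulp_up, rtol=1e-2, atol=1e-2))  # True
```

In the script itself, a more informative check than the bare `allclose` call would be to print `(out1.float() - out2.float()).abs().max()`, or to call `torch.allclose(out1, out2, rtol=1e-2, atol=1e-2)` or `torch.testing.assert_close(out1, out2)`, which picks looser default tolerances for bfloat16 inputs. Comparing both outputs against a float32 reference computation is another option. The exact tolerance values are a judgment call, not something established in this issue.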

Here is the environment.

Collecting environment information...
PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Alibaba Group Enterprise Linux Server 7.2 (Paladin) (x86_64)
GCC version: (GCC) 7.5.0
Clang version: Could not collect
CMake version: version 3.22.0
Libc version: glibc-2.32

Python version: 3.8.13 (default, Oct 21 2022, 23:50:54) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.112-005.ali5000.alios7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.3.58
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 470.154
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.4.0
/usr/lib64/libcudnn_adv_infer.so.8.4.0
/usr/lib64/libcudnn_adv_train.so.8.4.0
/usr/lib64/libcudnn_cnn_infer.so.8.4.0
/usr/lib64/libcudnn_cnn_train.so.8.4.0
/usr/lib64/libcudnn_ops_infer.so.8.4.0
/usr/lib64/libcudnn_ops_train.so.8.4.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] torch==1.13.0+cu111
[pip3] torchaudio==0.11.0
[pip3] torchvision==0.14.0
[conda] No relevant packages

A matching Triton is not available, some optimizations will not be enabled. Error caught was: module 'triton.language' has no attribute 'constexpr'
xFormers 0.0.15.dev+103e863.d20221125
memory_efficient_attention.flshatt: available - requires GPU with compute capability 7.5+
memory_efficient_attention.cutlass: available
memory_efficient_attention.small_k: available
swiglu.fused.p.cpp: available
is_triton_available: False
is_functorch_available: False
pytorch.version: 1.13.0
pytorch.cuda: available
gpu.compute_capability: 8.0
gpu.name: NVIDIA A100-SXM4-80GB

StellaAthena commented 1 year ago

This library was used to train the original model, but your issue seems better suited to the transformers or xformers repositories.