Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

misaligned address when using DropoutAddRMSNorm #289

Open KimmiShi opened 1 year ago

KimmiShi commented 1 year ago

Hi, I have a program that produces a tensor which is then passed to DropoutAddRMSNorm, and I get this error message:

CUDA Error: misaligned address .../flash-attention/csrc/layer_norm/ln_fwd_kernels.cuh 236
-> q = self.q_norm(q.transpose(1, 2).flatten(-2, -1))[0].view(B_, N_, H_, D_).transpose(1, 2)
(Pdb) p q.shape
torch.Size([2, 25, 257, 48])
(Pdb) p q.dtype
torch.bfloat16
(Pdb) p q
tensor([[[[-1.2506e-12, -1.5010e-13,  3.3396e-13,  ...,  5.8265e-13,
           -1.3589e-13, -5.2580e-13],
          [-1.2434e-12, -1.5454e-13,  3.3218e-13,  ...,  5.7909e-13,
           -1.4211e-13, -5.3646e-13],
          [-1.2506e-12, -1.4833e-13,  3.3573e-13,  ...,  5.8620e-13,
           -1.4033e-13, -5.3291e-13],
          ...,

So the tensor q looks fine, but if I then try to compute

self.q_norm(q.transpose(1, 2).flatten(-2, -1))[0].view(B_, N_, H_, D_)

I get:

*** RuntimeError: CUDA error: misaligned address

How can I check whether the tensor is misaligned?

KimmiShi commented 1 year ago

I ran it with compute-sanitizer and got:

========= Invalid __global__ read of size 16 bytes
=========     at 0x1e0 in void layer_norm::ln_fwd_kernel<layer_norm::Kernel_traits<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float, unsigned int, (unsigned int)1280, (unsigned int)1, (unsigned int)4, (unsigned int)1, (unsigned int)16, layer_norm::Kernel_traits_base<(unsigned int)1280, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float, unsigned int, (unsigned int)128>>, (bool)0, (bool)0, (bool)0, (bool)0>(layer_norm::FwdParams)
=========     by thread (105,0,0) in block (26,0,0)
=========     Address 0x7f02facf1832 is misaligned

Is this related to the ln_fwd_kernel implementation?

ptrblck commented 1 year ago

Originally reported here.

tridao commented 1 year ago

@KimmiShi Can you post a short script to reproduce the error? Something like:

# Construct DropoutAddRMSNorm module
# Generate q
# Pass q to the module, get error

Right now I can't reproduce the error.

KimmiShi commented 1 year ago

@KimmiShi Can you post a short script to reproduce the error? Something like:

Thanks @tridao, I cannot provide a short one; I will provide a larger one later.

I am curious: does DropoutAddRMSNorm require the user's inputs to be 16-byte aligned, or is this error produced inside the op itself?

tridao commented 1 year ago

It requires the last dimension to be a multiple of 8, as mentioned in the README. We do call .contiguous() and check that the dimension is divisible by 8. Maybe there are edge cases where something with a last dimension not divisible by 8 gets past this check.

tridao commented 1 year ago

If you can print out more info (shape, stride, dtype) of the input to DropoutAddRMSNorm, that would also help me reproduce the error, e.g. right before self.q_norm:

input = q.transpose(1, 2).flatten(-2, -1)
print(input.device, input.dtype, input.shape, input.stride())

KimmiShi commented 1 year ago

Thanks, I ran it with this code:

if self.qk_normalization:
    q, k, v = qkv.unbind(2)
    if self.qk_normalization_head_merged:
        if self.fuse_dal:
            input = q.flatten(-2, -1)
            print("q:", input.device, input.dtype, input.shape, input.stride(), flush=True)
            q = self.q_norm(input)[0].view(q.shape)
            torch.cuda.synchronize()

            input = k.flatten(-2, -1)
            print("k:", input.device, input.dtype, input.shape, input.stride(), flush=True)
            k = self.k_norm(input)[0].view(k.shape)
            torch.cuda.synchronize()

and got two lines of print output before the error:

q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
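
For reference, a standalone sketch that rebuilds tensors with the same shape, dtype, and strides as the prints above might look like the following. This is only an untested guess at the setup: the packed qkv layout is inferred from the qkv.unbind(2) call in the snippet, and the import path and constructor arguments of DropoutAddRMSNorm are assumptions that may not match the real model.

# Hypothetical repro sketch; qkv layout and DropoutAddRMSNorm arguments are assumptions.
import torch
from flash_attn.ops.rms_norm import DropoutAddRMSNorm

B, N, H, D = 2, 257, 25, 48                    # hidden size H * D = 1200

# qkv packed as [B, N, 3, H, D], as implied by qkv.unbind(2) above
qkv = torch.randn(B, N, 3, H, D, device="cuda", dtype=torch.bfloat16)
q, k, v = qkv.unbind(2)                        # strided views into qkv's storage

x = q.flatten(-2, -1)                          # [2, 257, 1200], stride (925200, 3600, 1)
print(x.device, x.dtype, x.shape, x.stride(), x.data_ptr() % 16)

q_norm = DropoutAddRMSNorm(H * D).to(device="cuda", dtype=torch.bfloat16)
out = q_norm(x)
torch.cuda.synchronize()                       # surface any asynchronous CUDA error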

KimmiShi commented 1 year ago

I tried with a version of the code that runs fine; it has the same prints for these two tensors:

q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)

KimmiShi commented 1 year ago

To reproduce, you need to use this DeepSpeed fork: https://github.com/KimmiShi/DeepSpeed/tree/bug; clone and pip install the repo.

flash-attn with the layer_norm extension is also needed. I am using PyTorch 1.13 + CUDA 11.7; I tried PyTorch 2.0 and it also reproduces.

Run it with torch distributed. I am running with Slurm: srun -p partition -n1 -N1 --gres=gpu:1 compute-sanitizer python deepspeed/debug_err.py. If not using Slurm, you need to modify line 599 (setup_distributed_slurm) to set up torch distributed correctly.

tridao commented 1 year ago

Thanks for the repro script, I've narrowed it down to a memory alignment problem. We expect all input tensors to be aligned to 16 bytes (in order to use vectorized load for max performance). Right now the error is due to the weight tensor inside RMSNorm not being aligned to 16 bytes.

I thought it was sufficient to ensure that the last dimensions are all divisible by 8 and that we call .contiguous(), and that this would lead to 16-byte alignment. Turns out that's not the case. I'll need to figure out a way to ensure memory alignment.
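
As a quick diagnostic for the earlier question about how to check alignment, something like the following sketch could be run right before the norm call (it assumes the module stores its scale as .weight):

x = q.flatten(-2, -1)    # the tensor handed to self.q_norm

# Requirement from the README: last dimension a multiple of 8.
print("last dim % 8 == 0:", x.shape[-1] % 8 == 0)

# What "misaligned address" is actually about: the kernel uses 16-byte
# vectorized loads, so the raw CUDA pointers must be multiples of 16.
print("input  16B-aligned:", x.contiguous().data_ptr() % 16 == 0)
print("weight 16B-aligned:", self.q_norm.weight.data_ptr() % 16 == 0)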

tridao commented 1 year ago

I don't know a reliable way to get 16-byte alignment, but I've posted a question to the PyTorch forum.

tridao commented 1 year ago

I pushed a commit to (hopefully) make sure that memory addresses are aligned to 16 bytes by cloning the inputs.
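
On the caller side, the same idea can be sketched as a small (hypothetical) helper that only copies when the pointer is off:

def ensure_16b_aligned(t):
    # Hypothetical helper, not the actual commit: cloning allocates fresh
    # storage, whose start address is in practice 16-byte aligned.
    return t if t.data_ptr() % 16 == 0 else t.clone()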

KimmiShi commented 1 year ago

Thanks! @tridao, the layer_norm is working now.

However, I came across another error when running debug_err.py. In debug_err.py, fused_mlp is used by default; here is the error output:

envs/pt13/lib/python3.9/site-packages/flash_attn/ops/fused_dense.py", line 257, in forward
     output1, *rest = fused_dense_cuda.linear_act_forward(
RuntimeError: linear_act_forward failed

I didn't get extra info using compute-sanitizer.

When I set use_fused_mlp=False, there is no error.

Can you reproduce this error?

tridao commented 1 year ago

Yes, I can reproduce it, but I don't have the bandwidth right now to debug it. I'm not familiar with DeepSpeed; I suspect it puts all parameters in a buffer and that can cause an alignment issue. For now I recommend not using fused_mlp with DeepSpeed.

KimmiShi commented 1 year ago

OK, thanks a lot. I also suspect that this error is related to DeepSpeed; I will look into this issue.

KimmiShi commented 1 year ago

I found out that this alignment issue is likely related to DeepSpeed flattening all the parameters of a model into one contiguous buffer. This can be triggered when using bf16 or ZeRO DDP. When using fp16 with a user-defined Adam optimizer (not FusedAdam), the parameters are not changed, and this error does not occur.

example: bf16 optimizer that flattens parameters: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/bf16_optimizer.py#L105

@ptrblck @tridao FYI, thanks a lot.
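
A small illustration of why flattening can break alignment (a hypothetical demo, not DeepSpeed code): a parameter carved out of one flat bf16 buffer starts at an arbitrary element offset, and with 2-byte elements an odd offset puts it on an address that is not a multiple of 16.

import torch

flat = torch.empty(10_000, device="cuda", dtype=torch.bfloat16)
print("flat buffer 16B-aligned:", flat.data_ptr() % 16 == 0)            # typically True

# Carve a 1200-element "weight" starting at element offset 3:
# 3 elements * 2 bytes = 6 bytes past the (aligned) buffer start.
weight_view = flat.narrow(0, 3, 1200)
print("carved weight 16B-aligned:", weight_view.data_ptr() % 16 == 0)   # False here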

SCZwangxiao commented 3 months ago

So what should we do if we want to keep bf16 training?

KimmiShi commented 3 months ago

@SCZwangxiao, as @tridao suggested: "For now I recommend not using fused_mlp with DeepSpeed."