KimmiShi opened this issue 1 year ago
I ran with compute-sanitizer and got:
========= Invalid __global__ read of size 16 bytes
========= at 0x1e0 in void layer_norm::ln_fwd_kernel<layer_norm::Kernel_traits<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float, unsigned int, (unsigned int)1280, (unsigned int)1, (unsigned int)4, (unsigned int)1, (unsigned int)16, layer_norm::Kernel_traits_base<(unsigned int)1280, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float, unsigned int, (unsigned int)128>>, (bool)0, (bool)0, (bool)0, (bool)0>(layer_norm::FwdParams)
========= by thread (105,0,0) in block (26,0,0)
========= Address 0x7f02facf1832 is misaligned
Is this related to the ln_fwd_kernel implementation?
@KimmiShi Can you post a short script to reproduce the error? Something like:
# Construct DropoutAddRMSNorm module
# Generate q
# Pass q to the module, get error
Right now I can't reproduce the error.
Thanks @tridao, I cannot provide a short one; I will provide a larger one later.
I am curious: does DropoutAddRMSNorm require the user's inputs to be 16-byte aligned, or is this error produced inside the op?
It requires the last dimension to be a multiple of 8, as mentioned in the README. We do call .contiguous() and check that the dimension is divisible by 8. Maybe there are some edge cases where something with a last dimension not divisible by 8 gets past this check.
If you can print out more info (shape, stride, dtype) of the input to DropoutAddRMSNorm, that would also help me reproduce the error. E.g., before self.q_norm:
input = q.transpose(1, 2).flatten(-2, -1)
print(input.device, input.dtype, input.shape, input.stride())
Thanks, I ran with this code:
if self.qk_normalization:
    q, k, v = qkv.unbind(2)
    if self.qk_normalization_head_merged:
        if self.fuse_dal:
            input = q.flatten(-2, -1)
            print("q:", input.device, input.dtype, input.shape, input.stride(), flush=True)
            q = self.q_norm(input)[0].view(q.shape)
            torch.cuda.synchronize()
            input = k.flatten(-2, -1)
            print("k:", input.device, input.dtype, input.shape, input.stride(), flush=True)
            k = self.k_norm(input)[0].view(k.shape)
            torch.cuda.synchronize()
I got two lines of print output before the error:
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
I tried with code that runs well; it has the same print output for these two tensors:
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
k: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
q: cuda:0 torch.bfloat16 torch.Size([2, 257, 1200]) (925200, 3600, 1)
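(Side note: with these strides, any misalignment has to come from the base address rather than the row offsets; a quick check, assuming the flattened input printed above:)
# bf16 is 2 bytes per element, so the byte offsets implied by stride (925200, 3600, 1) are
# 3600 * 2 = 7200 and 925200 * 2 = 1850400 bytes, both multiples of 16.
# Every row therefore inherits the alignment of the base pointer, which can be checked with:
print(input.data_ptr() % 16)  # 0 means the storage starts on a 16-byte boundary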
To reproduce, you need this DeepSpeed fork: https://github.com/KimmiShi/DeepSpeed/tree/bug; clone it and pip install the repo.
flash-attn with the layer_norm extension is also needed. I am using pt1.13+cu117; I tried pt2.0 and it also reproduces.
Run with torch distributed. I am running with slurm: srun -p partition -n1 -N1 --gres=gpu:1 compute-sanitizer python deepspeed/debug_err.py
If not using slurm, you need to modify line 599 (setup_distributed_slurm) to set up torch distributed correctly.
Thanks for the repro script, I've narrowed it down to a memory alignment problem.
We expect all input tensors to be aligned to 16 bytes (in order to use vectorized loads for max performance). Right now the error is due to the weight tensor inside RMSNorm not being aligned to 16 bytes.
I thought it was sufficient to ensure that the last dimensions are all divisible by 8 and to call .contiguous(), and that this would lead to 16-byte alignment. It turns out that's not the case, so I'll need to figure out a way to ensure memory alignment.
I don't know a reliable way to get 16-byte alignment, but I've posted a question to the PyTorch forum.
I pushed a commit to (hopefully) make sure that memory addresses are aligned to 16 bytes by cloning the inputs.
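(A minimal sketch of that kind of workaround, assuming an input tensor x; the helper below is illustrative, not the actual commit:)
import torch

def ensure_16b_aligned(x: torch.Tensor) -> torch.Tensor:
    # Fresh allocations from PyTorch's allocators are in practice well over-aligned,
    # so copying a misaligned tensor into new storage yields 16-byte-aligned data.
    return x if x.data_ptr() % 16 == 0 else x.contiguous().clone()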
Thanks @tridao, the layer_norm is working now.
However, I came across another error when running debug_err.py. In debug_err.py the fused_mlp is used by default; here is the error print:
envs/pt13/lib/python3.9/site-packages/flash_attn/ops/fused_dense.py", line 257, in forward
output1, *rest = fused_dense_cuda.linear_act_forward(
RuntimeError: linear_act_forward failed
I didn't get extra info using compute-sanitizer.
When I set use_fused_mlp=False, there is no error.
Can you reproduce this error?
Yes, I can reproduce it.
I don't have the bandwidth right now to debug it. I'm not familiar with DeepSpeed; I suspect it puts all parameters in one buffer, and that can cause alignment issues.
For now I recommend not using fused_mlp with DeepSpeed.
OK, thanks a lot. I also suspect that this error is related to DeepSpeed; I will look into this issue.
I found out that this alignment issue might be related to DeepSpeed flattening all of a model's parameters into one contiguous memory region. This can be triggered when using bf16 or ZeRO DDP. When using fp16 with a user-defined adam optimizer (not fusedadam), the parameters are not changed, and this error does not occur.
Example of the bf16 optimizer that flattens parameters: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/bf16_optimizer.py#L105
@ptrblck @tridao FYI, thanks a lot.
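(A hypothetical illustration of the suspected mechanism, not DeepSpeed's actual code; the parameter sizes are made up. Flattening parameters into one contiguous buffer and handing out views over it can put a parameter at a byte offset that is not a multiple of 16:)
import torch

# A small odd-sized parameter followed by a 1200-element norm weight,
# both flattened into a single contiguous bf16 buffer.
params = [torch.zeros(7, dtype=torch.bfloat16), torch.zeros(1200, dtype=torch.bfloat16)]
flat = torch.cat([p.flatten() for p in params])  # one allocation for everything

offset = 0
for p in params:
    view = flat.narrow(0, offset, p.numel())  # the parameter now aliases the flat buffer
    print(view.data_ptr() % 16)               # first view: typically 0; second view: 14 (7 elements * 2 bytes)
    offset += p.numel()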
So what should we do if we want to keep bf16 training?
@SCZwangxiao as @tridao suggested: For now I recommend not using fused_mlp with DeepSpeed.
Hi, I have a program which produces a tensor that is then passed to DropoutAddRMSNorm, and I got an error message. The tensor q looks OK, but when I try to compute it I get:
(Pdb) *** RuntimeError: CUDA error: misaligned address
How can I check whether the tensor is misaligned?
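(For reference, a quick way to check this in PyTorch, assuming the tensor is called q:)
print(q.data_ptr() % 16)             # non-zero means the storage does not start on a 16-byte boundary
print(q.stride(), q.element_size())  # strides are in elements; multiply by element_size() to get byte offsets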