state-spaces / mamba

Mamba SSM architecture

Loss NaN in Mamba2 #352

Closed: tyshiwo1 closed this issue 4 months ago

tyshiwo1 commented 5 months ago

Hello guys,

When I applied Mamba2 to image generation, I found several NaN values in the gradients (ddt_bias, dx, and ddt_given) in _mamba_chunk_scan_combined_bwd of mamba_ssm/ops/triton/ssd_combined.py, so the loss becomes NaN.

The image generation code is DiM; I just replaced the original Mamba-1 block with Mamba-2. I used bf16 precision to train from scratch, and the NaN appears in the first training iteration.

My environment is triton==2.2.0, torch==2.2.1+cu121.

If anyone can help me, I will be very grateful!

zzzendurance commented 5 months ago

Sorry to bother you. I haven't run into the same problem as you, but I'd like to ask you a couple of questions.

1. Have you ever run into the following error?

File "/data/zh/miniconda3/envs/man/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
File "/data/zh/wa1/-main/IP/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

From what I found, this is caused by the causal_conv1d version. Mine was originally 1.1.1, and after switching to 1.0.2 it still throws the error. Which causal_conv1d version are you using?

2. How did you use Mamba2 directly in your own project? Did you download the .whl file and then update the mamba_ssm package in your virtual environment, or did you download the whole repository and use the mamba2 files from the project? If you downloaded a .whl, how did you choose the version, and what is the difference between the abiTRUE and abiFALSE builds here?

mamba_ssm-2.0.3+cu118torch1.13cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
mamba_ssm-2.0.3+cu118torch1.13cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

Sorry again for the bother, looking forward to your reply~

tyshiwo1 commented 5 months ago
  1. My causal_conv1d version is 1.2.2.post1
  2. I downloaded the whole project and compiled it locally with CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
zzzendurance commented 5 months ago

Hahaha, thanks! I updated my causal_conv1d version to 1.2.2.post1 and training runs now, but then I hit the same NaN problem as you. With Mamba-1 and the same hyperparameters I don't think I ever got NaN.

tyshiwo1 commented 5 months ago

Yes, the same issue

tyshiwo1 commented 5 months ago

I found that this line of code causes the NaN in dx:

https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_combined.py#L176

Kiet0712 commented 5 months ago

Have you tried using float32 instead of bfloat16?

tyshiwo1 commented 5 months ago

Have you tried using float32 instead of bfloat16?

The NaN remains when I use fp32 to train.

zzzendurance commented 5 months ago

Wow. Do you know how to change this line of code (acc += tl.dot(cb, dout)) to solve the NaN problem? (I'm not good at this, so I don't know how to do it.)

tyshiwo1 commented 5 months ago

Wow. Do you know how to change this line of code (acc += tl.dot(cb, dout)) to solve the NaN problem? (I'm not good at this, so I don't know how to do it.)

I am working on it. The code around this line may also cause bugs, and other variables like ddt_bias also contain NaN, so a lot of code may need to change.

zzzendurance commented 5 months ago

I don't really understand it, but it sounds impressive. I'll dig into it myself as well.

So do you think this is a bug in Mamba2 itself, or does the NaN only appear because people are applying it to their own tasks (i.e. everyone needs to adapt the code specifically to their own task)?

tyshiwo1 commented 5 months ago

I don't really understand it, but it sounds impressive. I'll dig into it myself as well.

So do you think this is a bug in Mamba2 itself, or does the NaN only appear because people are applying it to their own tasks (i.e. everyone needs to adapt the code specifically to their own task)?

I applied Mamba2 to image generation rather than NLP tasks, and got a NaN loss.

tridao commented 5 months ago

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.
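For example, a minimal sketch of that pattern (illustrative only; the tensor names just mirror the arguments of _mamba_chunk_scan_combined_bwd, and the file path is arbitrary):

import torch

def dump_if_nan(dx, path="nan_repro_inputs.pt", **inputs):
    # If the gradient contains NaN, save the inputs that produced it so the
    # failing call can be replayed offline on the same tensors.
    if dx.isnan().any():
        torch.save({k: v.detach().cpu() for k, v in inputs.items()}, path)

# Offline: tensors = torch.load("nan_repro_inputs.pt"), then call the same
# function on tensors["x"], tensors["dt"], ... and check dx for NaN.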
tyshiwo1 commented 5 months ago

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.

Thank you for your reply! I have uploaded my tensors to Google Drive.

The error occurs on this line, so I saved the input and output tensors into a zip file.

I also tried some changes, like replacing

dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)

with

dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=-1e6).to(tl.float32)

but this does not help with other variables like dout.
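For readers unfamiliar with Triton, here is a minimal standalone sketch (not the library kernel) of what the other= argument does: lanes masked out of the load are filled with that value instead of being read from memory, which is why changing 0.0 to -1e6 alters the downstream arithmetic.

import torch
import triton
import triton.language as tl

@triton.jit
def masked_sum_kernel(x_ptr, out_ptr, n, fill, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Lanes with offs >= n are never read; they take the value `fill` instead.
    x = tl.load(x_ptr + offs, mask=offs < n, other=fill).to(tl.float32)
    tl.store(out_ptr, tl.sum(x, axis=0))

x = torch.randn(5, device="cuda")
out = torch.empty(1, device="cuda", dtype=torch.float32)
masked_sum_kernel[(1,)](x, out, x.numel(), 0.0, BLOCK=8)  # padded lanes contribute 0.0
print(out.item(), x.sum().item())                         # the two sums should match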

XiudingCai commented 4 months ago

A smaller A range (close to 1) and a smaller chunk size may make the training more stable

Thanks for the suggestion. I set chunk_size to 1, which pushed the NaN back to a later iteration, but it still appeared.

realwenlongwang commented 4 months ago

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.
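For reference, a minimal sketch of that setting, assuming the A_init_range keyword exposed by the Mamba2 module in mamba_ssm (parameter names as in mamba_ssm/modules/mamba2.py; d_model here is illustrative):

import torch
from mamba_ssm import Mamba2

block = Mamba2(
    d_model=256,            # with the default expand=2 and headdim=64 this gives 8 heads
    A_init_range=(1, 1.1),  # narrower range, so A starts close to 1
).cuda()

x = torch.randn(2, 1024, 256, device="cuda")  # (batch, seqlen, d_model)
y = block(x)                                  # same shape out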

XiudingCai commented 4 months ago

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

It did.

tyshiwo1 commented 4 months ago

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

This stabilizes training for the first few iterations, but the loss still becomes NaN later.

ZijianYY commented 4 months ago

Same issue. In my code, the loss suddenly becomes NaN.

tyshiwo1 commented 4 months ago

Same issue. In my code, the loss suddenly becomes NaN.

Maybe you can also locate the NaN values and provide the tensors like I did 😂.

bio-mlhui commented 4 months ago

If it is still not stable, the last resort is to lower the learning rate.

ZijianYY commented 4 months ago

Same issue. In my code, the loss suddenly becomes NaN.

Maybe you can also locate the NaN values and provide the tensors like I did 😂.

Checked. It is also the ddt_bias tensor as you mentioned before.

ZijianYY commented 4 months ago

If it is still not stable, the last resort is to lower the learning rate.

Tried it. The result is the same as decreasing the chunk size: it stays stable for more epochs, but the loss becomes NaN later.

tyshiwo1 commented 4 months ago

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

if dx.isnan().any():
    # save tensors to disk with torch.save

So that we can reproduce it like this:

# load x, b, c, dt, etc from disk with torch.load
# whatever function here that caused NaN
# we observe NaN in dx for example.

@ZijianYY You can follow the instructions here. I have uploaded my tensors to https://drive.google.com/drive/folders/1ojmQNDsAToNZaP3ZNOAeMJBu1AshMnXS?usp=sharing for this function https://github.com/state-spaces/mamba/blob/26283fbfa16dd0c2e054aab3f140ce10d3d02e6e/mamba_ssm/ops/triton/ssd_combined.py#L403 . You may also check it.

Maykeye commented 4 months ago

Would be hard for us to say what's causing NaN until we can reproduce it. Can you save all the tensors right before the function call that produced NaN and share with us? Sth like

Here's another very primitive barebones "model" that very quickly generates NaN on a 3080 Ti laptop.

The model creates a random 8x8 "RGB image" and then tries to create an "upscaled" 64x64 version using Mamba and 64x64 random values that are supposed to represent "I'm the N-th pixel; who am I, considering the past?".

The number of layers and d_model matter; dtype doesn't (both bfloat16 and float32 fail at the same epoch). expand, d_state, etc. are left at their defaults.

With Mamba2Simple I get ValueError: NaN loss at epoch #2. With Mamba2 I get ValueError: NaN loss at epoch #1. With Mamba1 I lose patience after 100 iterations: no NaN appears.

The start of the file sets the parameters (the most important is THE_MAMBA, which chooses which class is used for Mamba: Mamba, Mamba2, or Mamba2Simple).
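For readers who do not want to open the attached file, a stripped-down sketch in the same spirit (not the original script; THE_MAMBA, d_model, and n_layers are illustrative, and the Mamba2 interface from mamba_ssm is assumed):

import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2

THE_MAMBA = Mamba2            # swap in Mamba or Mamba2Simple to compare behaviour
d_model, n_layers = 256, 4
device = "cuda"

class TinyUpscaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp = nn.Linear(3, d_model)
        self.blocks = nn.ModuleList([THE_MAMBA(d_model=d_model) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, 3)

    def forward(self, x):     # x: (B, 64*64, 3) "who am I?" query pixels
        h = self.inp(x)
        for blk in self.blocks:
            h = h + blk(h)    # residual around each Mamba block
        return self.out(h)

torch.manual_seed(0)
model = TinyUpscaler().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Random 8x8 "RGB image", upscaled to 64x64 and flattened into a token sequence.
target = F.interpolate(torch.rand(1, 3, 8, 8, device=device), size=64)
target = target.flatten(2).transpose(1, 2)    # (1, 4096, 3)
query = torch.rand_like(target)

for epoch in range(100):
    loss = F.mse_loss(model(query), target)
    if loss.isnan():
        raise ValueError(f"NaN loss at epoch #{epoch}")
    opt.zero_grad()
    loss.backward()
    opt.step()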

EddieEduardo commented 4 months ago

Same here, NaN raised when training with Mamba2.

catalpaaa commented 4 months ago

I don't really understand it, but it sounds impressive. I'll dig into it myself as well. So do you think this is a bug in Mamba2 itself, or does the NaN only appear because people are applying it to their own tasks (i.e. everyone needs to adapt the code specifically to their own task)?

I applied Mamba2 to image generation rather than NLP tasks, and got a NaN loss.

I also get NaN on an image classification task :(

zzzendurance commented 4 months ago

I don't really understand it, but it sounds impressive. I'll dig into it myself as well. So do you think this is a bug in Mamba2 itself, or does the NaN only appear because people are applying it to their own tasks (i.e. everyone needs to adapt the code specifically to their own task)?

I applied Mamba2 to image generation rather than NLP tasks, and got a NaN loss.

I also get NaN on an image classification task :(

So is this problem still unresolved? (I get NaN on a voice classification task.)

tridao commented 4 months ago

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

EddieEduardo commented 4 months ago

My sequence length is 256, my task is object detection, it raised nan 😭


tridao commented 4 months ago

My sequence length is 256, my task is object detection, it raised nan 😭

Please make a reproducible script (e.g. save the tensors right before the function that causes NaN). If we can't reproduce we can't do anything.

tyshiwo1 commented 4 months ago

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Thank you! We will try it.

tyshiwo1 commented 4 months ago

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Thank you! We will try it.

This fix works for me. At least the loss remains stable over a few thousand training iterations. Thanks again!

catalpaaa commented 4 months ago

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Huge fix! Training works with d_model = 256 and 512, but once I lower d_model to 192, the error

RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8

is produced. If I follow #362 to force the training to run, it produces NaN for any d_model.

Any plan on fixing this issue? Let me know what you need.

Maykeye commented 4 months ago

We pushed a fix, can you guys try v2.0.4?

Works now!

catalpaaa commented 4 months ago

Thanks for the bug reports. We were able to reproduce the NaN gradients when sequence length is not a multiple of 256. We pushed a fix, can you guys try v2.0.4?

Huge fix! Training works with d_model = 256 and 512, but once I lower d_model to 192, the error

RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8

is produced. If I follow #362 to force the training to run, it produces NaN for any d_model.

Any plan on fixing this issue? Let me know what you need.

My bad for not investigating further. Your fix is perfect and we should not use #362 with it; all we need is to make sure d_model * expand / headdim is a multiple of 8.
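A quick sanity check for that constraint before training (assuming Mamba2's default expand=2 and headdim=64; both are constructor arguments, so adjust if you override them):

def check_d_model(d_model, expand=2, headdim=64):
    # Per the report above, training errors out (or goes NaN with the #362 workaround)
    # unless d_model * expand / headdim is a multiple of 8.
    nheads = d_model * expand / headdim
    ok = nheads.is_integer() and nheads % 8 == 0
    print(f"d_model={d_model}: d_model*expand/headdim = {nheads} -> {'ok' if ok else 'stride error / NaN risk'}")

for d in (192, 256, 512):
    check_d_model(d)  # 192 -> 6.0 (fails), 256 -> 8.0 (ok), 512 -> 16.0 (ok)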

TimothyChen225 commented 4 months ago
  1. My causal_conv1d version is 1.2.2.post1
  2. I downloaded the whole project and compiled it locally with CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .

Yes, it did work for me.

drhuangliwei commented 3 months ago

I think chunk_size does not really matter. Empirically, I found that setting A_init_range to (1, 1.1) works for me.

After running for a while, the NaN comes back.