Hprairie / Bi-Mamba2

A Triton Kernel for incorporating Bi-Directionality in Mamba2

Some questions about Bi-Mamba2 and Hydra #1

Open Liu-zhi-chao opened 2 months ago

Liu-zhi-chao commented 2 months ago

Hi Hayden, thank you very much for this exciting and excellent work! I have a few questions I would like to ask you for help.

(1) First, your Bi-Mamba2 quasiseparable matrix omits the shift() operation used in the Hydra method. In that case, can your method still share the parameters of the two matrix mixers as Hydra does? Are the other parameters, such as B and C, handled consistently with the Hydra paper?

(2) How can I use your Bi-Mamba2? Can I just replace the mamba_chunk_scan_combined() in Hydra with bimamba_chunk_scan_combined() to run it?

Looking forward to your reply, thank you again for your excellent work!

GLOMQuyet commented 2 months ago

I'd like to ask the same questions, and also to check how much this reduces GPU VRAM usage. Additionally, I would like to experiment with Hydra for computer vision based on this GitHub repository, as the original Hydra is too slow and not as practical as Mamba 1.

Hprairie commented 2 months ago

Hey, thanks for reaching out,

  1. Yes, my method assumes that weights are shared between the two scans. The best way to visualize the matrix mixer for Bi-Mamba2 is to go to page 6 of the Hydra paper and look at Figure 2.c. Bi-Mamba2 is formulated as the sum of two separate SSMs, plus $Dx(t)$.
  2. I haven't tested it, but yes, simply replacing mamba_chunk_scan_combined() in Hydra with bimamba_chunk_scan_combined() should work (see the sketch just below this list). I will also be adding a layer implementation soon, like Hydra has, to abstract all of that away. Just busy with the start of school rn.
  3. As for VRAM usage, I still have to benchmark it, but it should be less than Hydra's. PyTorch's flip operation makes a copy in memory, which means the flip-based approach already takes up extra space for no reason (see the flip snippet at the end of this comment). Bi-Mamba2 leaves all tensors in place and reads them in chunks to calculate both the fwd and bwd scans simultaneously in a single kernel. I'll add a figure to the benchmark section once I create something to actually measure it.
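
Roughly, the swap looks like this. A minimal sketch only: the wrapper name hydra_inner_scan is just a stand-in for wherever Hydra makes the call, and the import path follows this repo's src/ssd/bi/ssd_combined.py layout, so adjust both as needed.

# Before (Hydra's unidirectional kernel):
# from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# After (this repo's fused bidirectional kernel); import path may need adjusting:
from ssd.bi.ssd_combined import bimamba_chunk_scan_combined


def hydra_inner_scan(x, dt, A, B, C, chunk_size, D=None, z=None):
    # Hypothetical stand-in for wherever Hydra calls the scan; only the call changes.
    # Previously: return mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z)
    return bimamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, z=z)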

I'm glad people are wanting to use this, as the subquadratic nature of SSMs really starts to show with optimized bi-directional kernels.

Lmk if you have any other questions, or if you find any issues with the kernel. I think that it should be very stable rn (i.e. passing all the tests I created), but something tiny might have been overlooked.
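
To make the flip point from item 3 concrete: torch.flip always returns a copy (it is documented as not being a view, unlike NumPy's flip), so a flip-based bidirectional pass allocates a full extra copy of each flipped tensor before the backward scan even runs. A quick illustration with arbitrary sizes:

import torch

x = torch.randn(32, 400, 8, 64)               # (batch, seqlen, nheads, headdim), sizes arbitrary
x_rev = torch.flip(x, dims=(1,))              # materializes a full copy of x, not a view
print(x.data_ptr() == x_rev.data_ptr())       # False: a separate allocation
print(x_rev.element_size() * x_rev.nelement() / 2**20, "MiB extra just for the reversed input")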

Hprairie commented 2 months ago

Oh, I also want to note that this kernel appears to be much faster on NVIDIA GPUs due to tensor cores and better Triton optimizations. So for now I would recommend NVIDIA for the best performance; however, AMD will still work.

Liu-zhi-chao commented 2 months ago

Hi Hayden, thanks a lot for your reply! I tried running it by directly replacing mamba_chunk_scan_combined() with bimamba_chunk_scan_combined() in Hydra, but I encountered an error in the backpropagation process:

File "/root/CodeSource/Detection/multiYolov8/Bi-Mamba2/src/ssd/bi/ssd_combined.py", line 578, in backward
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, *rest = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, dfinal_states=dfinal_states, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
File "/root/CodeSource/Detection/multiYolov8/Bi-Mamba2/src/ssd/bi/ssd_combined.py", line 349, in _mamba_chunk_scan_combined_bwd
  assert dt.shape == (batch, seqlen, nheads)
AssertionError

I printed out the shape of the dt variable in ssd/bi/ssd_combined.py:

def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None, dt_bias=None, dfinal_states=None, dt_softplus=False, dt_limit=(0.0, float("inf")), dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    print("ssd_combined bwd shape of dt should be:", batch, seqlen, nheads)
    print("ssd_combined bwd real shape of dt:", dt.shape)

The print result is:

ssd_combined bwd shape of dt should be: 32 400 8
ssd_combined bwd real shape of dt: torch.Size([32, 8, 2, 256])

It is worth noting that the shape of dt in the forward pass is correct ([32, 400, 8]).

Liu-zhi-chao commented 2 months ago

I further printed the dt shape in the forward function of MambaChunkScanCombinedFn and found that it is already abnormal there:

class MambaChunkScanCombinedFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
        ctx.dt_dtype = dt.dtype
        out, out_x, dt, dA_cumsum_f, dA_cumsum_b, states_f, states_b, final_states_f, final_states_b = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
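        # Note: per the shapes printed below, the dt returned here appears to be the
        # chunked dt, (batch, nheads, nchunks, chunk_size) = (32, 8, 2, 256),
        # rather than the original (batch, seqlen, nheads) input dt.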

        print("MCSCF fwd dt shape:", dt.shape)

Print result: MCSCF fwd dt shape: torch.Size([32, 8, 2, 256]), but I think it should be [32, 400, 8].

Hprairie commented 2 months ago

Hmm, interesting. dt is in the correct chunked shape, but that shouldn't be what gets output. Let me take a look and try to fix it / create a working API.

Hprairie commented 2 months ago

Can you also paste a script so that I can reproduce your error?
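
Something along these lines would be enough (sizes taken from your printout; headdim, ngroups, dstate, and the import path are placeholders, so adjust them to your setup):

import torch
from ssd.bi.ssd_combined import bimamba_chunk_scan_combined  # import path may differ

batch, seqlen, nheads, headdim = 32, 400, 8, 64
ngroups, dstate, chunk_size = 1, 64, 256
dev = "cuda"

x  = torch.randn(batch, seqlen, nheads, headdim, device=dev, requires_grad=True)
dt = torch.rand(batch, seqlen, nheads, device=dev, requires_grad=True)
A  = (-torch.rand(nheads, device=dev)).requires_grad_()   # A is typically negative in Mamba2 (assumption)
B  = torch.randn(batch, seqlen, ngroups, dstate, device=dev, requires_grad=True)
C  = torch.randn(batch, seqlen, ngroups, dstate, device=dev, requires_grad=True)
D  = torch.randn(nheads, device=dev, requires_grad=True)

out = bimamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D)
out.sum().backward()   # this is where the dt-shape assertion should fire
print(x.grad.shape, dt.grad.shape)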

Hprairie commented 2 months ago

Okay, some updates. I have fixed the issue with the shape: I was accidentally saving the wrong dt for the bwd pass. This should be good now. After looking through the Hydra module, I have created my own module which replicates it without the extra compute/parameters. I have added this BiMamba2 module, which people can use now.

I want to note again that I have only tested these modules for numerical correctness and don't know about the training dynamics. Hopefully people can try them out and let me know!

Liu-zhi-chao commented 2 months ago

Hi Hayden, thank you very much for fixing the bug mentioned above. I also tried Bi-Mamba2 in my network, and it is a great improvement in speed compared to Hydra; amazing work. I am currently studying how to combine Bi-Mamba2 with Yolov8, but in my preliminary experiments I found that no matter how I adjust the optimizer type and the learning rate, the loss is always NaN. It is worth noting that the loss was normal when using Hydra before. In further attempts, I found that if I set AMP=True the loss can be computed, but this is not what I want. So I would like to ask whether Bi-Mamba2 makes particular FP16/FP32 choices in the underlying kernel code. Could you give me some suggestions to help me get training working normally with AMP=False?

Hprairie commented 2 months ago

Hmmm, very weird. Are the gradients becoming NaN, or are the outputs? I.e., do you think you can identify where the NaN is coming from? I have a couple of things I can check, but it is a little weird that it doesn't work with fp32. I know the original Triton kernels used bf16 or fp32, so if training is stable with both of those then idk how much I can do. More information would be useful, as I can't reproduce the issue. A toy training script where this happens would be amazing.
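
In the meantime, one way to narrow it down with plain PyTorch tooling (nothing specific to this repo) is to turn on anomaly detection and/or put NaN-checking hooks on the tensors going into the scan:

import torch

# Option 1: anomaly detection points at the autograd op that produced the first NaN/inf.
torch.autograd.set_detect_anomaly(True)

# Option 2: gradient hooks report which tensor's gradient turns NaN first.
# `named_tensors` is whatever dict of the scan inputs (x, dt, A, B, C, ...) you build.
def add_nan_hooks(named_tensors):
    for name, t in named_tensors.items():
        if t.requires_grad:
            t.register_hook(lambda grad, name=name:
                            print(f"NaN grad in {name}") if torch.isnan(grad).any() else None)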

GLOMQuyet commented 2 months ago

As the author said, the segsum must be exact, otherwise it will give NaN; check whether the forward and backward of the segsum are correct.

Hprairie commented 2 months ago

Yes, I use a stable segsum calculation method. Currently, my method is arithmetically correct, meaning that you can download this repo and run pytest. Come back in a couple of hours and it should be passing everything. I have not been able to recreate the NaN.

Also, as a side note: all SSD kernels use some subtraction for an approximate answer. I won't go into it in this comment, but it should be fine. I really need to know which kernel to look at, i.e., where the NaN gradients first appear. For example, they could be appearing in $dx$ first, which would help me fix the kernel. I can't do much without input and output tensors that reproduce the error.
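
For context on what I mean by a stable segsum: the reference (pure PyTorch, non-Triton) formulation from the Mamba2 minimal SSD code builds the segment sums with a masked cumulative sum instead of subtracting two cumsums. A sketch of that idea, not the actual Triton kernel in this repo:

import torch

def segsum_stable(x: torch.Tensor) -> torch.Tensor:
    # x: (..., T). Output: (..., T, T) with entry [i, j] = sum_{j < k <= i} x[..., k]
    # on and below the diagonal, and -inf above it. Built with a masked cumsum rather
    # than cumsum[..., :, None] - cumsum[..., None, :], where precision issues creep in.
    T = x.size(-1)
    x = x[..., None].expand(*x.shape, T)                                   # (..., T, T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=0)
    return x_segsum.masked_fill(~mask, float("-inf"))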

Liu-zhi-chao commented 2 months ago

Thanks to Hayden and Trương for their attention to these questions. After many attempts, I found that, compared to the original Hydra architecture, I needed a larger batch size to avoid the NaN loss (NaN loss is generally caused by exploding or vanishing gradients). But this is a challenge for my GPU memory. Since the entire network I designed is based on the Yolov8 project, I am not sure how to provide more detailed examples to help you find the problem. Anyway, I am very grateful for your patient answers and help!

Liu-zhi-chao commented 2 months ago

Hello, @Hprairie. I did some tests on the problem of NaN gradients caused by increasing the number of Bi-Mamba blocks. I printed the gradient values of dx, ddt, dA, etc. during backpropagation. The test code is as follows:

class MambaChunkScanCombinedFn(torch.autograd.Function):
    @staticmethod
    def backward(ctx, dout, *args):
        out, x, dt, dA_cumsum_f, dA_cumsum_b, A, B, C, D, z, dt_bias = ctx.saved_tensors
        assert not ctx.return_final_states, "final states are not currently supported in the bwd pass"
        dfinal_states = args[0] if ctx.return_final_states else None
        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, *rest = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, dfinal_states=dfinal_states, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)

        # Print gradients
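        # (dx_values_over_batches is assumed to be a module-level list defined elsewhere,
        #  e.g. dx_values_over_batches = [])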
        dx_values_over_batches.append(dx.mean().item())
        print(f"Batch {len(dx_values_over_batches)} - dx Mean: {dx.mean().item()}, ddt Mean: {ddt.mean().item()}, dA Mean: {dA.mean().item()}, dB Mean: {dB.mean().item()}, dC Mean: {dC.mean().item()}, dD Mean: {dD.mean().item() if dD is not None else 'None'}, dz Mean: {dz.mean().item() if dz is not None else 'None'}, dt_bias Mean: {dt_bias.mean().item() if dt_bias is not None else 'None'}")

        return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, None, None, None

Test results:

Batch 1 - dx Mean: -6.49321023615812e-08, ddt Mean: nan, dA Mean: nan, dB Mean: 1.4940897585802304e-07, dC Mean: 2.9840339088593737e-09, dD Mean: 0.00012194819282740355, dz Mean: None, dt_bias Mean: None
Batch 2 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 3 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 4 - dx Mean: 4.271685583034923e-08, ddt Mean: nan, dA Mean: nan, dB Mean: 1.901244051083495e-09, dC Mean: 3.929191549900679e-08, dD Mean: -0.0003782338462769985, dz Mean: None, dt_bias Mean: None
Batch 5 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 6 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 7 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 8 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 9 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 10 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 11 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 12 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 13 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 14 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None
Batch 15 - dx Mean: nan, ddt Mean: nan, dA Mean: nan, dB Mean: nan, dC Mean: nan, dD Mean: nan, dz Mean: None, dt_bias Mean: None

I found that at the beginning of training the gradients of dx, dB, dC, and dD were not NaN, but after several batches they all became NaN. In addition, someone in this issue also ran into NaN gradients (they found that a particular line of code causes NaN in dx). I don't know if this information helps you.

Hprairie commented 2 months ago

Okay, this is super helpful. The NaNs appear in ddt and dA first, which is interesting. I'll look into this.

Hprairie commented 2 months ago

Can you help me out by printing each variable at this line: https://github.com/Hprairie/Bi-Mamba2/blob/08b3cd3cf6d60ee2f4e712f6efecebe86ec15f92/src/ssd/bi/ssd_combined.py#L466C18-L466C19. That will identify which kernel the NaNs are coming from.
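
For example, something like this dropped in right after that line (the ddA_* names are the partial gradients computed around there; adjust if they differ):

# Inside _mamba_chunk_scan_combined_bwd, right after the linked line:
for name, t in [("ddA_b", ddA_b), ("ddA_f", ddA_f),
                ("ddA_next_f", ddA_next_f), ("ddA_next_b", ddA_next_b),
                ("ddA_prev_b", ddA_prev_b), ("ddA_prev_f", ddA_prev_f)]:
    print(f"  {name} (mean): {t.float().mean().item()}")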

Liu-zhi-chao commented 2 months ago

OK, I just tested it and the results are as follows:

###Batch 1###
  ddA_b (mean): 8.174757567758206e-06
  ddA_f (mean): 1.179828541353345e-05
  ddA_next_f (mean): 4.442523731995607e-06
  ddA_next_b (mean): 3.3115450150944525e-06
  ddA_prev_b (mean): nan
  ddA_prev_f (mean): nan
  ddA (Total mean): nan
###Batch 2###
  ddA_b (mean): -1.2388027244014665e-05
  ddA_f (mean): 1.1606797670538072e-05
  ddA_next_f (mean): -6.844269137218362e-06
  ddA_next_b (mean): -3.871251010423293e-06
  ddA_prev_b (mean): nan
  ddA_prev_f (mean): nan
  ddA (Total mean): nan
###Batch 3###
  ddA_b (mean): nan
  ddA_f (mean): nan
  ddA_next_f (mean): nan
  ddA_next_b (mean): nan
  ddA_prev_b (mean): nan
  ddA_prev_f (mean): nan
  ddA (Total mean): nan
###Batch 4###
  ddA_b (mean): nan
  ddA_f (mean): nan
  ddA_next_f (mean): nan
  ddA_next_b (mean): nan
  ddA_prev_b (mean): nan
  ddA_prev_f (mean): nan
  ddA (Total mean): nan

The results show that as soon as training starts, the gradients ddA_prev_b and ddA_prev_f are already NaN.