MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

Running csm_triton.py keeps hitting CUDA out of memory #294

Open · BaophanN opened this issue 2 months ago

BaophanN commented 2 months ago

When I run csm_triton.py, I keep getting this error:

  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 720, in <module>
    CHECK.check_csm_triton()
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 610, in check_csm_triton
    res0 = triton.testing.do_bench(lambda :cross_scan(x))
  File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 121, in do_bench
    fn()
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 610, in <lambda>
    res0 = triton.testing.do_bench(lambda :cross_scan(x))
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 555, in cross_scan
    xs = torch.stack([
torch.cuda.OutOfMemoryError: CUDA out of memory.

I do not understand why this function keeps running out of memory. Because of it, triton.testing.do_bench(lambda: cross_scan(x)) fails, so I cannot benchmark the function's execution time. Here is the function:

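        import torch

        def cross_scan(x: torch.Tensor):
            # Why out of memory?
            B, C, H, W = x.shape
            L = H * W  # number of spatial positions after flattening H and W
            xs = torch.stack([
                # row-major scan: a view, no copy (requires x to be contiguous)
                x.view(B, C, L),
                # column-major scan: swap H and W; contiguous() copies the data
                torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
                # reversed row-major scan: flip() always returns a copy
                torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
                # reversed column-major scan: one copy from contiguous(), another from flip()
                torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)  # stack copies all four into a new (B, 4, C, L) tensor
            return xs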
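To see which of the four entries are actually copies, the snippet below compares each entry's storage pointer with that of x. It is a sketch that assumes PyTorch 2.x (for untyped_storage()); the helper shares_storage and the small shape are hypothetical. It runs on CPU, so no GPU memory is needed.

```python
import torch


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # True when both tensors point into the same underlying storage,
    # i.e. producing `a` from `b` did not copy any data.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


B, C, H, W = 2, 3, 4, 5  # small hypothetical shape; CPU is enough here
L = H * W
x = torch.randn(B, C, H, W)

# The same four expressions as in cross_scan above
scans = {
    "row-major view": x.view(B, C, L),
    "transposed + contiguous": torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
    "flipped": torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
    "transposed + flipped": torch.flip(
        torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]
    ),
}
for name, t in scans.items():
    print(f"{name:>25}: copy={not shares_storage(t, x)}")

# torch.stack allocates one more buffer holding all four scans
stacked = torch.stack(list(scans.values()), dim=1)
print(f"{'torch.stack output':>25}: copy={not shares_storage(stacked, x)}, "
      f"size={stacked.numel() // x.numel()}x of x")
```

Only the first entry shares memory with x. The other three are fresh allocations (the PyTorch docs note that torch.flip always copies), and the fourth goes through two copies, first contiguous() and then flip(). torch.stack then copies all four again into an output four times the size of x.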

This is only a simple tensor manipulation, so why does it run out of memory? Thank you.
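For a sense of scale, here is some back-of-the-envelope arithmetic. It is a sketch: the helper cross_scan_sizes and the example shape are hypothetical, not taken from the issue. It counts only x and the stacked output and ignores the transient copies made along the way.

```python
def cross_scan_sizes(B: int, C: int, H: int, W: int, bytes_per_elem: int = 4):
    """Return (input_bytes, output_bytes) for the cross_scan snippet.

    Only the input x and the final stacked (B, 4, C, L) output are counted;
    the temporary copies made by contiguous() and flip() come on top of this.
    """
    numel = B * C * H * W
    input_bytes = numel * bytes_per_elem
    output_bytes = 4 * input_bytes  # four scan directions stacked on dim=1
    return input_bytes, output_bytes


# Hypothetical float32 input shape (not from the issue)
inp, out = cross_scan_sizes(128, 96, 56, 56)
print(f"x: {inp / 2**20:.0f} MiB, stacked output: {out / 2**20:.0f} MiB")
# prints: x: 147 MiB, stacked output: 588 MiB
```

The stacked output alone grows linearly with B, C, H and W, so a shape that fits comfortably as x can still exhaust memory once the output and intermediates are allocated.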

MzeroMiko commented 2 months ago

There are many copy operations in this "simple" manipulation, and it needs at least 10x as much memory as the tensor x itself. That may be the cause of the OOM.
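If the full-size input is really needed, one way to lower the transient memory is to write each scan straight into a single preallocated output instead of building separate copies and stacking them. The function below is a hypothetical sketch, not VMamba's implementation. It still needs the 4x output buffer, but it never keeps the intermediate copies alive together.

```python
import torch


def cross_scan_prealloc(x: torch.Tensor) -> torch.Tensor:
    """Same four scans as cross_scan, written into one preallocated buffer.

    Hypothetical variant (not from VMamba): each scan is written directly into
    its slice of the output. Only the flips still need a temporary, and at most
    one of those is alive at a time.
    """
    B, C, H, W = x.shape
    L = H * W
    xs = x.new_empty((B, 4, C, L))  # same dtype/device as x
    # Scan 0: row-major flatten (a view when x is contiguous, so no extra copy)
    xs[:, 0].copy_(x.reshape(B, C, L))
    # Scan 1: column-major; viewing the slice as (B, C, W, H) lets the
    # transposed tensor be copied in without a contiguous() temporary
    xs[:, 1].view(B, C, W, H).copy_(x.transpose(2, 3))
    # Scans 2 and 3: reversed scans 0 and 1 (each flip allocates one temporary)
    xs[:, 2].copy_(xs[:, 0].flip(-1))
    xs[:, 3].copy_(xs[:, 1].flip(-1))
    return xs
```

Before swapping it into the benchmark, it is worth checking it against the original on a small CPU input with H != W (for example with torch.equal). If only timings are needed, shrinking B, H or W in the benchmark input is the simpler way around the OOM.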