When I run the file csm_triton.py, I keep getting this error:
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 720, in <module>
    CHECK.check_csm_triton()
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 610, in check_csm_triton
    res0 = triton.testing.do_bench(lambda :cross_scan(x))
  File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 121, in do_bench
    fn()
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 610, in <lambda>
    res0 = triton.testing.do_bench(lambda :cross_scan(x))
  File "/workspace/source/VMamba/classification/models/csm_triton.py", line 555, in cross_scan
    xs = torch.stack([
torch.cuda.OutOfMemoryError: CUDA out of memory.
I do not understand why this function keeps running out of memory. Because of it, triton.testing.do_bench(lambda :cross_scan(x)) cannot finish, so I cannot benchmark the execution time of this function:
import torch

def cross_scan(x: torch.Tensor):
    # why out of memory?
    B, C, H, W = x.shape
    L = H * W  # sequence length after flattening H and W
    xs = torch.stack([
        # reshape: row-major scan
        x.view(B, C, L),
        # transpose H and W, then reshape: column-major scan
        torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
        # flip along the L dim
        torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
        # transpose H and W, reshape into BCL, then flip along L
        torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
    ], dim=1).view(B, 4, C, L)
    return xs
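For a rough sense of what this function allocates, the sketch below counts the buffers it materializes. It assumes x is already contiguous and that PyTorch releases each temporary as soon as its last reference is dropped; the helper name cross_scan_memory_estimate is hypothetical. The three non-view list entries (one copy each) are all alive while torch.stack writes its 4x-sized output, so the peak extra usage is about 7 copies of x on top of x itself.

```python
def cross_scan_memory_estimate(B, C, H, W, bytes_per_elem=4):
    """Hypothetical helper: rough allocation estimate (bytes) for cross_scan.

    Assumes x is contiguous and temporaries are freed once unreferenced.
    """
    n = B * C * H * W * bytes_per_elem  # size of one (B, C, L) copy of x
    # Tensors alive when torch.stack runs:
    #   x.view(...)                         -> a view, no new memory
    #   transpose(...).contiguous()         -> 1 copy
    #   flip(x.contiguous().view(...))      -> 1 copy (contiguous() is a no-op here)
    #   flip(transpose(...).contiguous())   -> 1 copy kept; the contiguous
    #                                          copy it came from is transient
    stack_inputs = 3 * n
    # torch.stack allocates a fresh (B, 4, C, L) output; the final
    # .view(B, 4, C, L) is free because it is only a view
    output = 4 * n
    return {
        "input_x": n,
        "stack_inputs": stack_inputs,
        "output": output,
        "peak_extra": stack_inputs + output,
    }


# Hypothetical shape, not taken from check_csm_triton:
r = cross_scan_memory_estimate(128, 96, 56, 56)
print(f"peak extra: {r['peak_extra'] / 2**30:.2f} GiB")
```

If the tensor built in check_csm_triton is large, this roughly 8x footprint (x plus about 7 more copies), together with whatever else the script already holds on the GPU, may be enough to exceed device memory.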
This is only a simple tensor manipulation, so why does it run out of memory? Thank you.
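One way to lower the peak is to write each scan direction straight into a single preallocated output, so no separate list of tensors has to be stacked (and copied again). The sketch below assumes the same (B, C, H, W) input; cross_scan_prealloc is a hypothetical name, not part of VMamba. It produces the same values as cross_scan, but the 4x output itself is unavoidable, and flip() still allocates one (B, C, L) temporary per call.

```python
import torch


def cross_scan_prealloc(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical lower-memory variant of cross_scan: fill one
    # (B, 4, C, L) buffer in place instead of calling torch.stack.
    B, C, H, W = x.shape
    L = H * W
    out = x.new_empty((B, 4, C, L))
    # direction 0: row-major scan (plain reshape of x)
    out[:, 0].copy_(x.reshape(B, C, L))
    # direction 1: column-major scan; view the slot as (B, C, W, H) so the
    # transposed input is copied in without an extra contiguous() buffer
    out[:, 1].view(B, C, W, H).copy_(x.transpose(2, 3))
    # directions 2 and 3: the first two scans reversed along L; each
    # flip() allocates one temporary (B, C, L) tensor
    out[:, 2].copy_(out[:, 0].flip(-1))
    out[:, 3].copy_(out[:, 1].flip(-1))
    return out
```

On a GPU, the two versions can be compared by calling torch.cuda.reset_peak_memory_stats() before a single call and reading torch.cuda.max_memory_allocated() afterwards. If even this variant runs out of memory, the input used for the benchmark is probably too large for the device.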