OpenGVLab / Vision-RWKV

Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures
https://arxiv.org/abs/2403.02308
Apache License 2.0
288 stars · 11 forks

CUDA error for the RWKV6 testing #13

Closed dongzhuoyao closed 2 months ago

dongzhuoyao commented 2 months ago

Hi, I ran the following simple test for RWKV6, but it fails with the error below.

Have you encountered this before? Could you share your exact Python, CUDA, and torch versions?

model = RWKV6_Model()

x = torch.rand(10, in_channels, img_dim, img_dim).to("cuda")
o = model(x)
o.backward(torch.randn_like(x))

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

duanduanduanyuchen commented 2 months ago

Hi, maybe you should try something like o.backward(torch.randn_like(o)), since the gradient passed to backward must match the output's shape, not the input's. This error also looks like it could be an OOM error.
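For context, a minimal sketch of that suggestion (the model and shapes here are hypothetical stand-ins, not RWKV6_Model):

```python
import torch

# The tensor passed to backward() must match the shape of the *output* o,
# not the input x; randn_like(x) would raise a shape-mismatch error.
model = torch.nn.Linear(8, 4)               # stand-in for RWKV6_Model
x = torch.rand(2, 8, requires_grad=True)
o = model(x)                                # shape (2, 4), unlike x's (2, 8)
o.backward(torch.randn_like(o))             # gradient shaped like the output
assert x.grad is not None and x.grad.shape == x.shape
```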

dongzhuoyao commented 2 months ago

Thanks for your reply, but it doesn't solve this issue. I managed to narrow the problem down to this line:

https://github.com/OpenGVLab/Vision-RWKV/blob/9e458035730790ca2faafaed1b2d9446f5b80b1d/classification/mmcls_custom/models/backbones/vrwkv6.py#L69

but I have no clue how to solve it.

duanduanduanyuchen commented 2 months ago

Will this error still happen if you reduce the batch size or img size?

dongzhuoyao commented 2 months ago

Yes, it still happens, and I am the only user on the GPU (an A100).

duanduanduanyuchen commented 2 months ago

Thanks for your comment. I'll let you know if I figure it out later.

dongzhuoyao commented 2 months ago
[screenshot of the error attached]
dongzhuoyao commented 2 months ago

Python 3.11, CUDA 11.8, torch 2.2.0

dongzhuoyao commented 2 months ago

For RWKV4, the same test runs successfully.

duanduanduanyuchen commented 2 months ago

Hi, maybe increasing T_MAX in vrwkv6.py can solve this.
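For reference, a sketch of why T_MAX matters, assuming vrwkv6.py follows the usual RWKV CUDA-binding layout (the constant value and patch size below are assumptions, not taken from the repo): the kernel is compiled for a fixed maximum sequence length, and a vision input whose patch count exceeds it makes the kernel index past its buffers, which surfaces as an illegal memory access.

```python
# Sketch (assumed layout): the RWKV CUDA kernel is compiled with a fixed
# maximum sequence length T_MAX; sequences longer than T_MAX overrun the
# kernel's buffers, producing the illegal-memory-access error above.
T_MAX = 8192  # hypothetical compile-time constant in vrwkv6.py

def seq_len(img_dim: int, patch_size: int) -> int:
    # For ViT-style patching, the sequence length is the number of patches.
    return (img_dim // patch_size) ** 2

assert seq_len(224, 4) <= T_MAX   # 3136 patches: fits
assert seq_len(512, 4) > T_MAX    # 16384 patches: needs a larger T_MAX
```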

dongzhuoyao commented 2 months ago

Thanks, increasing T_MAX works for me.