EasonXiao-888 / GrootVL

The official implementation of GrootVL: Tree Topology is All You Need in State Space Model
An illegal memory access was encountered in BFS #1

Closed. Vp-SoLo closed this issue 3 months ago.

Vp-SoLo commented 3 months ago

Thank you for your excellent work! I am planning to build on it with some modifications. While running the following code to analyze GrootV, I got a `CUDA kernel launch failed: an illegal memory access was encountered` error, and after some troubleshooting I traced it to the BFS part. Here is my test code:

```python
import torch

from classification.models.grootv import GrootV

device = torch.device('cuda:1')

model = GrootV(
    num_classes=10,
    channels=80,
    depths=[2, 2, 9, 2],
    layer_scale=None,
    post_norm=False,
    mlp_ratio=4.0,
    with_cp=False,
    drop_path_rate=0.1,
).to(device)

x = torch.rand(8, 3, 64, 64).to(device)

x = model(x)
```

My CUDA version is 11.8. How can I fix this bug?

EasonXiao-888 commented 3 months ago

@Vp-SoLo
Thanks for your interest, and sorry for not replying sooner! When I run the demo you provided, I hit the same bug. However, if I replace `device = torch.device('cuda:1')` with `os.environ['CUDA_VISIBLE_DEVICES'] = '7'` followed by `device = torch.device('cuda')`, it runs successfully. Maybe you can try it this way?
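A minimal sketch of that workaround, assuming it runs at the very top of the script: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized in the process. The helper name `restrict_to_single_gpu` is hypothetical and is not part of GrootVL or PyTorch. With a single visible GPU, that GPU is always addressed as 'cuda' (equivalently 'cuda:0').

```python
import os
import sys


def restrict_to_single_gpu(physical_index: int) -> str:
    """Hypothetical helper: expose exactly one physical GPU to this process.

    Must run before CUDA is initialized; once CUDA is initialized,
    changing CUDA_VISIBLE_DEVICES has no effect on the current process.
    """
    torch_mod = sys.modules.get("torch")
    # Importing torch is fine; what matters is whether CUDA is already initialized.
    if torch_mod is not None and torch_mod.cuda.is_initialized():
        raise RuntimeError(
            "CUDA is already initialized; set CUDA_VISIBLE_DEVICES earlier"
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_index)
    # The only visible GPU is renumbered as logical device 0, i.e. 'cuda'.
    return "cuda"
```

Usage, assuming the same model setup as in the issue: call `device = torch.device(restrict_to_single_gpu(7))` before any other CUDA call, then build the model with `GrootV(...).to(device)` as before.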

Vp-SoLo commented 3 months ago

@EasonXiao-888 Thank you very much for your answer! I also found a way to work around this bug by modifying the code to:

```python
import torch

from classification.models.grootv import GrootV

model = GrootV(
    num_classes=10,
    channels=80,
    depths=[2, 2, 9, 2],
    layer_scale=None,
    post_norm=False,
    mlp_ratio=4.0,
    with_cp=False,
    drop_path_rate=0.1,
).cuda()

x = torch.rand(8, 3, 64, 64).cuda()

x = model(x)
```

It seems that when CUDA_VISIBLE_DEVICES exposes more than one CUDA device, explicitly specifying a device such as 'cuda:1' causes this error. I suspect there is a bug in the CUDA source code related to BFS. In any case, I am now able to use GrootV in my work. Thanks for your help!
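The device numbering behind this observation can be sketched as follows, assuming CUDA_VISIBLE_DEVICES holds a comma-separated list of integer GPU indices (it may also hold GPU UUIDs, which this simplified, hypothetical helper does not handle). PyTorch's 'cuda:N' refers to the N-th entry of that list, and `.cuda()` with no argument uses the current device, which defaults to logical device 0. With one visible GPU, the tensors and any kernel that implicitly runs on the current device must land on the same GPU; with several, 'cuda:1' places the tensors on a GPU other than the default current one. That pattern is consistent with, but does not confirm, a BFS kernel that launches on the current device rather than on the input tensor's device.

```python
import os
from typing import Optional


def physical_gpu_for(logical_index: int, visible: Optional[str] = None) -> str:
    """Hypothetical helper: map a logical CUDA index ('cuda:N') to a physical GPU id.

    `visible` defaults to the CUDA_VISIBLE_DEVICES environment variable.
    If the variable is unset, logical and physical indices coincide.
    Simplified: assumes integer ids only (no UUIDs).
    """
    if visible is None:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return str(logical_index)
    # Visible devices are renumbered 0, 1, 2, ... in the order listed.
    ids = [part.strip() for part in visible.split(",") if part.strip()]
    if not 0 <= logical_index < len(ids):
        raise IndexError(f"cuda:{logical_index} is not visible (visible ids: {ids})")
    return ids[logical_index]
```

For example, with CUDA_VISIBLE_DEVICES='6,7', 'cuda:1' maps to physical GPU 7 while the default current device (logical 0) maps to physical GPU 6; with the maintainer's setting of '7', both 'cuda' and the default current device map to physical GPU 7.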