google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0

None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms #86

Closed shiyi099 closed 1 day ago

shiyi099 commented 8 months ago

My Python package configuration is jax == 0.4.13 and jaxlib == 0.4.13+cu12+cudnn89.

My hardware is an NVIDIA H800.

The OS is Ubuntu 20.04 LTS (x86_64).

When I run inference with the TAPNET or TAPIR models, I get the message “None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms”, and the outputs seem wrong: they differ from the outputs I get on Windows (RTX 3060). What is going wrong, and how can I solve it?

cdoersch commented 8 months ago

This seems more like a JAX issue than an issue with TAPIR. I unfortunately don't have easy access to H800 GPUs and so I can't easily reproduce this issue. Your best bet might be to try to isolate the op which is producing different outputs on the two different devices, and then create a simple reproduction of the issue that you can use to file a bug against JAX or XLA.
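One way to start isolating an op, sketched here under assumptions: run a single candidate op on the default backend (the GPU, if JAX sees one) and again on the CPU of the same machine, then compare the results. The convolution below is only a stand-in for whichever op is actually suspect, and `conv_op` and `max_abs_diff_vs_cpu` are hypothetical helper names, not part of TAPIR or JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np


def conv_op(x, w):
    # Stand-in for the suspect op: a plain NHWC convolution (an assumption,
    # not necessarily the op that misbehaves in TAPIR).
    return jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


def max_abs_diff_vs_cpu(fn, *args):
    """Run fn on the default backend and on CPU; return the largest absolute difference."""
    cpu = jax.devices("cpu")[0]
    # Inputs created without an explicit device live on the default backend.
    out_default = np.asarray(jax.jit(fn)(*args))
    # Committing the inputs to the CPU device makes the computation run there.
    cpu_args = [jax.device_put(a, cpu) for a in args]
    out_cpu = np.asarray(jax.jit(fn)(*cpu_args))
    return float(np.max(np.abs(out_default - out_cpu)))


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 32, 32, 8), dtype=jnp.float32)
w = jax.random.normal(key_w, (3, 3, 8, 16), dtype=jnp.float32)

print("default backend:", jax.default_backend())
diff = max_abs_diff_vs_cpu(conv_op, x, w)
print("max |default - cpu| =", diff)
```

Small differences are expected from floating-point rounding and operation ordering; a difference that is large relative to the output magnitude would suggest the op is worth reporting. Swapping `conv_op` for other ops with the same input shapes and dtypes as in the model narrows the search further.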

shiyi099 commented 8 months ago

Thank you! Whatever JAX configuration I try, nothing seems to change. I tried switching my JAX version to 0.4.x (cudnn88), but unfortunately the same message appears. https://github.com/google/jax/issues/17523 describes a similar problem. I have now tried the PyTorch models of TAPNET on the H800, after modifying some code in the tapnet/pytorch folder. It works!