Closed shiyi099 closed 1 day ago
This seems more like a JAX issue than an issue with TAPIR. I unfortunately don't have easy access to H800 GPUs and so I can't easily reproduce this issue. Your best bet might be to try to isolate the op which is producing different outputs on the two different devices, and then create a simple reproduction of the issue that you can use to file a bug against JAX or XLA.
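One way to narrow this down, as a rough sketch rather than anything TAPIR-specific: dump a few intermediate activations to disk on each machine, then compare the two dumps and report the first entry that disagrees. The helper name `first_divergence`, the layer names, and the tolerances below are all hypothetical, and the data is synthetic; only NumPy is assumed.

```python
import numpy as np


def first_divergence(ref, other, rtol=1e-3, atol=1e-5):
    """Return (name, max_abs_diff) for the first entry that differs, else None.

    `ref` and `other` map names to arrays, listed in execution order.
    """
    for name in ref:  # dicts keep insertion order
        if name not in other:
            return name, float("inf")
        a = np.asarray(ref[name], dtype=np.float64)
        b = np.asarray(other[name], dtype=np.float64)
        if a.shape != b.shape:
            return name, float("inf")
        if not np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=True):
            return name, float(np.max(np.abs(a - b)))
    return None


# Synthetic stand-ins for activations dumped on the two machines.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
ref = {"input": x, "conv1": x * 2.0, "conv2": x * 2.0 + 1.0}
other = {"input": x, "conv1": x * 2.0, "conv2": x * 2.0 + 1.5}

result = first_divergence(ref, other)
print(result)
```

In practice the two dicts could come from `np.savez` files written on each device and loaded with `dict(np.load(path))`. The first name reported is a good candidate for a standalone reproduction.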
Thank you! Nothing I change in the JAX configuration seems to help. I tried switching my jax version to 0.4.x (cudnn88), but unfortunately the same message appears. https://github.com/google/jax/issues/17523 describes a similar problem. I have now tried the PyTorch models of TAPNET on the H800, after modifying some of the code in the tapnet/pytorch folder. It works!
My Python package configuration is jax == 0.4.13, jaxlib == 0.4.13+cu12+cudnn89
My hardware is an NVIDIA H800
The OS is Ubuntu 20.04 LTS, x86_64
When I try to run inference with the TAPNET or TAPIR models, it shows “None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms”. The outputs seem wrong, and they differ from those I get on Windows (RTX 3060). What is wrong, and how can I solve it?