hongzhengdong opened this issue 1 year ago
That's weird. I wouldn't worry about it: the errors here are small enough that they are probably due to floating-point differences on the hardware running the code (maybe the tests are running on a GPU but the tolerances were tuned on a CPU?). I would just ignore this and check whether the rest of the code gives you good results. In the meantime, I will loosen some of the numerical tolerances in the tests.
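Here is a minimal sketch of why a check can fail at a tight tolerance yet pass at a looser one. It assumes, as a hypothesis rather than a confirmed diagnosis, that the failing tests compare float32 results against float64 references. JAX computes in float32 by default unless `jax_enable_x64` is enabled, while NumPy defaults to float64. The helpers `to_f32`, `sum_f32` and `allclose` are hypothetical illustrations, not part of this repository. `allclose` uses the same criterion as NumPy's `isclose`, which is |a - b| <= atol + rtol * |b|.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_f32(values) -> float:
    """Sum values, rounding to float32 after every addition (float32 accumulation)."""
    total = 0.0
    for v in values:
        total = to_f32(total + to_f32(v))
    return total


def allclose(a: float, b: float, rtol: float, atol: float = 0.0) -> bool:
    """Same criterion as numpy.isclose: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


values = [0.1] * 10_000
reference = math.fsum(values)  # correctly rounded float64 sum
result = sum_f32(values)       # float32 accumulation, as a GPU computation might do

print(f"float64: {reference!r}, float32: {result!r}")
# rtol=1e-7 is the default of numpy.testing.assert_allclose
print("rtol=1e-7:", allclose(result, reference, rtol=1e-7))
# a loosened tolerance of the kind mentioned above
print("rtol=1e-3:", allclose(result, reference, rtol=1e-3))
```

The float32 sum drifts from the float64 reference by far more than 1e-7 relative, but by much less than 1e-3. This thread does not establish that precision is what trips these particular tests. Loosening the tolerances, or re-running the JAX tests after `jax.config.update("jax_enable_x64", True)`, would be two ways to probe it.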
What hardware are you running this on? The test breakage in the dataset code looks unusual to me. Do you have multiple GPUs? I think the unit tests might assume you have just a single accelerator.
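If multiple GPUs turn out to be the problem, one common workaround is to expose only one device to the process through the `CUDA_VISIBLE_DEVICES` environment variable. The CUDA runtime honours it, so `jax.devices()` then reports a single GPU. In a shell this is simply `CUDA_VISIBLE_DEVICES=0 ./scripts/run_all_unit_tests.sh`. The sketch below is a hypothetical Python wrapper that does the same thing. `single_gpu_env` and `run_tests_on_one_gpu` are illustrative names, not part of the repository.

```python
import os
import subprocess
from typing import Mapping, Optional


def single_gpu_env(gpu_index: int = 0, base_env: Optional[Mapping[str, str]] = None) -> dict:
    """Copy the environment and expose only one CUDA device to child processes."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


def run_tests_on_one_gpu(script: str = "./scripts/run_all_unit_tests.sh", gpu_index: int = 0) -> int:
    """Run the test script with a single visible GPU and return its exit code."""
    completed = subprocess.run([script], env=single_gpu_env(gpu_index), check=False)
    return completed.returncode
```

Calling `run_tests_on_one_gpu(gpu_index=0)` from the repository root would run the suite on the first GPU only. The caller's environment is copied rather than modified, so the restriction applies only to the child process.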
Hi, I am encountering this problem when I run ./scripts/run_all_unit_tests.sh after setting up the environment as you described:
My CUDA version is 11.3, my cuDNN version is 8.3.2, and I am using an RTX 3090.
The failures are as follows:
When I run the tests with the original NumPy, they all pass without any problem, so I am wondering whether this is a JAX issue.
I have tried several different JAX and jaxlib versions, but none of them fixes it.
Could you please help me fix this problem?