miss-rain closed this issue 2 years ago.
@miss-rain I think this is related to your GPU type, your CUDA version, and whether they are compatible with the installed JAX version. It is not specific to this code.
Could you please first make sure that you can install JAX and that JAX detects your GPU (see https://github.com/google/jax/issues/971)?
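A minimal sketch of such a check, assuming a standard JAX install. It only uses the public `jax.__version__`, `jax.default_backend()` and `jax.devices()`; the helper name `check_jax_devices` is hypothetical, and this is not necessarily the exact procedure from the linked issue. If the output lists only CPU devices, JAX did not find a usable GPU/TPU, which usually points to a jaxlib/CUDA mismatch rather than to this repository's code.

```python
import jax


def check_jax_devices():
    """Print what JAX can see and return the default backend name (hypothetical helper)."""
    # Installed JAX version; GPU support also depends on a matching
    # jaxlib/CUDA build being installed alongside it.
    print("jax version:", jax.__version__)

    # Backend JAX uses by default, e.g. "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    print("default backend:", backend)

    # Every device visible to JAX, with its platform name.
    devices = jax.devices()
    for d in devices:
        print(" ", d.platform, d)

    # A CPU-only device list means computations will not run on a GPU/TPU.
    if all(d.platform == "cpu" for d in devices):
        print("Only CPU devices found: JAX will not use a GPU/TPU.")
    return backend


backend = check_jax_devices()
```

On a correctly configured CUDA machine the default backend should be reported as `gpu`; on a CPU-only install it will be `cpu`.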
Please refer to issue #1.
Both on Google Colab with a TPU and on my own Ubuntu machine with CUDA,
your code runs only on the CPU, not on the GPU or TPU!