monk279 opened this issue 2 months ago
I ran into the same problem, found a fix, and came here to ask the developers for a change, only to find your question :) Downgrading JAX to 0.4.23 solves it:
!pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
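To check that the pin actually took effect, you could run a small cell like the one below after the install. Colab typically also needs a runtime restart when an already-imported package is replaced. This is only a sketch: `keyarray_status` is a hypothetical helper name, and all it does is report the installed JAX version and whether the attribute named in the error message resolves.

```python
import importlib
import importlib.metadata


def keyarray_status() -> str:
    """Report the installed JAX version and whether jax.random.KeyArray exists."""
    try:
        version = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return "jax is not installed"
    try:
        random_mod = importlib.import_module("jax.random")
    except ImportError as exc:
        # e.g. jax is installed but jaxlib is missing or broken
        return f"jax {version} is installed but failed to import: {exc}"
    # hasattr() is False when the module no longer provides the name
    has_keyarray = hasattr(random_mod, "KeyArray")
    return f"jax {version}: jax.random.KeyArray {'present' if has_keyarray else 'missing'}"


print(keyarray_status())
```

With the 0.4.23 pin the line should report the attribute as present. If it still says "missing", the old JAX version is probably still loaded in the running session.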
The error "module 'jax.random' has no attribute 'KeyArray'" occurs in the second cell of both Google Colab notebooks. It looks like something goes wrong inside the diffusers library from Hugging Face.
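Here is a plausible reading of the error, which I have not verified against the diffusers source. Newer JAX releases deprecated `jax.random.KeyArray` in favor of `jax.Array` and later dropped it, and pinning 0.4.23 evidently brings the name back. diffusers' Flax code presumably still refers to the old name, probably in a type annotation that is evaluated when the module is imported. Under that assumption, a stopgap that avoids downgrading is to alias the missing name before importing diffusers. This is a monkeypatch, not an official JAX or diffusers API. `alias_keyarray` is a hypothetical helper, and mapping `KeyArray` to `jax.Array` only restores the name. It does not change any behavior.

```python
import types


def alias_keyarray(random_module: types.ModuleType, array_type: type) -> bool:
    """Add `KeyArray` to `random_module` as an alias for `array_type`.

    Returns True if the alias was added, or False if the attribute already
    existed (as on jax 0.4.23). In that case nothing is changed.
    """
    if hasattr(random_module, "KeyArray"):
        return False
    setattr(random_module, "KeyArray", array_type)
    return True


# Only touch the real module if JAX imports cleanly. Run this cell *before*
# importing diffusers, assuming the failing lookup happens at import time.
try:
    import jax
except ImportError:
    jax = None

if jax is not None:
    alias_keyarray(jax.random, jax.Array)

# import diffusers  # hypothetical next step in the notebook
```

Pinning JAX is the safer option, because the pinned version really provides the attribute. The alias only helps if `KeyArray` is used purely as a type name.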