When I run `pytest .`, one test fails:

FAILED gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0 - ValueError: Operands could not be broadcast together for add on shapes (2, 1, 1, 5) (5,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see htt...

Also, when I run Sampling.py, it fails, stopping at the last step:

![image](https://github.com/google-deepmind/gemma/assets/164460376/1b878a4a-464c-41a3-a4a1-159d113fae97)

Is there any relationship between the two failures, or is there another reason for the sampling failure?

My environment: JAX CPU version, Windows 10, CUDA 12.1, jax==0.25.0.
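
For reference, the test error appears to be reproducible outside the gemma code. The following is a minimal sketch, assuming only that JAX is installed. The names `x`, `pos`, and `try_add` are made up for illustration and are not taken from the gemma test. It adds arrays of the same two shapes from the error message, (2, 1, 1, 5) and (5,), first with rank promotion set to `'raise'` and then with it set to `'allow'`.

```python
import jax
import jax.numpy as jnp


def try_add(a, b):
    """Return (succeeded, result shape or error message) for a + b."""
    try:
        return True, str((a + b).shape)
    except ValueError as err:
        return False, str(err)


# Hypothetical stand-ins with the shapes from the error message.
x = jnp.zeros((2, 1, 1, 5))
pos = jnp.arange(5.0)

# With rank promotion set to 'raise', adding arrays of different rank fails.
with jax.numpy_rank_promotion("raise"):
    strict_ok, strict_msg = try_add(x, pos)
    # Giving both operands the same rank avoids the implicit promotion.
    explicit_ok, explicit_shape = try_add(x, pos[None, None, None, :])

# With rank promotion set to 'allow', the same add broadcasts silently.
with jax.numpy_rank_promotion("allow"):
    allow_ok, allow_shape = try_add(x, pos)

print("raise, implicit:", strict_ok)
print("raise, explicit:", explicit_ok, explicit_shape)
print("allow, implicit:", allow_ok, allow_shape)
```

If this behaves the same way on the setup above, the pytest failure would come from the rank promotion setting rather than from the platform. It does not show whether the Sampling.py failure is related.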