google-deepmind / gemma

Open weights LLM from Google DeepMind.
http://ai.google.dev/gemma
Apache License 2.0

Issue with unit tests on NVIDIA V100 (GPU) #32

Open DwarKapex opened 1 month ago

DwarKapex commented 1 month ago

Hi everyone.

I see failures when running the unit tests on an NVIDIA V100 (GPU). Here is the link for more details.

Briefly:

```
=========================== short test summary info ============================
FAILED opt/gemma/gemma/layers_test.py::EinsumTest::test_rmsnorm0 - AssertionE...
FAILED opt/gemma/gemma/modules_test.py::FeedForwardTest::test_ffw0 - Assertio...
FAILED opt/gemma/gemma/positional_embeddings_test.py::PositionalEmbeddingsTest::test_adds_positional_embeddings0
================== 3 failed, 13 passed, 2 warnings in 35.61s ===================
```

Some details:
1. test_rmsnorm0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:348)). This looks like a floating-point rounding (epsilon) error. I don't think it's a good idea to compare the expected array of floats directly with the computed one. Would it be possible to allow a small tolerance between the expected and computed arrays, e.g. `rtol=1e-5, atol=1e-5`?
2. test_ffw0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:415)) is similar to the previous one.
3. test_adds_positional_embeddings0 ([link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/9099672951/job/25013672689?pr=590#step:7:486)). IMHO, JAX cannot handle it correctly on GPUs.
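To illustrate the tolerance suggested in item 1, here is a minimal sketch using NumPy's `np.testing.assert_allclose`. The arrays and the two helper functions are hypothetical and are not taken from the gemma test suite. The sketch assumes the failing tests compare the result against a hard-coded reference array, which I have not verified.

```python
import numpy as np

# Hypothetical values, NOT taken from the failing tests: `expected` stands in
# for a hard-coded reference array, `actual` for a GPU result that differs
# from it only by float32 rounding noise.
expected = np.array([0.5, 1.25, -2.0], dtype=np.float32)
actual = expected + np.float32(1e-6)


def arrays_match_exactly(a, b):
    """Exact element-wise equality: fragile across different hardware."""
    return bool(np.array_equal(a, b))


def arrays_match_within_tolerance(a, b, rtol=1e-5, atol=1e-5):
    """True when |a - b| <= atol + rtol * |b| holds for every element."""
    try:
        np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
    except AssertionError:
        return False
    return True


exact_ok = arrays_match_exactly(actual, expected)              # False
tolerant_ok = arrays_match_within_tolerance(actual, expected)  # True
```

With these made-up numbers the exact comparison fails while the tolerant one passes. The right values for `rtol` and `atol` depend on the dtype and the operation, so `1e-5` is only a starting point.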

Thank you for your help! Hope it's fixable! =)

DwarKapex commented 3 weeks ago

Hi everyone. Any update on this problem?