stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

Fatal Python error (Aborted) when running gpt2_test.py #697

Closed: DwarKapex closed this issue 1 week ago

DwarKapex commented 3 weeks ago

Hi all.

In JAX-Toolbox we see the following error on both V100 and A100 GPUs:

opt/levanter/tests/gpt2_test.py Fatal Python error: Aborted

Thread 0x00007f02ebfff640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Current thread 0x00007f0c4cd781c0 (most recent call first):
  File "/opt/jax/jax/_src/compiler.py", line 266 in backend_compile
  File "/opt/jax/jax/_src/profiler.py", line 333 in wrapper
  File "/opt/jax/jax/_src/compiler.py", line 654 in _compile_and_write_cache
  File "/opt/jax/jax/_src/compiler.py", line 426 in compile_or_get_cached
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2670 in _cached_compilation
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2857 in from_hlo
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2344 in compile
  File "/opt/jax/jax/_src/pjit.py", line 1674 in _pjit_call_impl_python
  File "/opt/jax/jax/_src/pjit.py", line 1744 in call_impl_cache_miss
  File "/opt/jax/jax/_src/pjit.py", line 1768 in _pjit_call_impl
  File "/opt/jax/jax/_src/core.py", line 939 in process_primitive
  File "/opt/jax/jax/_src/core.py", line 433 in bind_with_trace
  File "/opt/jax/jax/_src/core.py", line 2761 in bind
  File "/opt/jax/jax/_src/pjit.py", line 190 in _python_pjit_helper
  File "/opt/jax/jax/_src/pjit.py", line 356 in cache_miss
  File "/opt/jax/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
  File "/opt/haliax/src/haliax/__init__.py", line 875 in subtract
  File "/opt/haliax/src/haliax/wrap.py", line 97 in binop
  File "/opt/haliax/src/haliax/core.py", line 577 in __sub__
  File "/opt/haliax/src/haliax/nn/normalization.py", line 47 in __call__
  File "/opt/levanter/src/levanter/models/gpt2.py", line 299 in __call__
  File "/usr/lib/python3.10/contextlib.py", line 79 in inner
  File "/opt/levanter/src/levanter/models/gpt2.py", line 399 in __call__
  File "/opt/levanter/tests/gpt2_test.py", line 41 in test_gradient_checkpointing
  File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 1627 in runtest
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 174 in pytest_runtest_call
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 242 in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 341 in from_call
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 241 in call_and_report
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 337 in _main
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 283 in wrap_session
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 175 in main
  File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 201 in console_main
  File "/usr/local/bin/pytest", line 8 in <module>

Could you please take a look? Thanks in advance!

dlwh commented 3 weeks ago

Seems like a JAX error?
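In the traceback, the abort happens inside `backend_compile` while XLA compiles the `subtract` called from Haliax's normalization layer, during `test_gradient_checkpointing`. One way to test whether the problem lies in JAX rather than Levanter is to run a similar computation in plain JAX on the same container and GPUs. The sketch below is hypothetical: `layer_norm`, `block`, the shapes, and the use of `jax.checkpoint` only loosely mirror the call chain in the traceback. It is not Levanter's or Haliax's actual code.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Hypothetical stand-in for a LayerNorm: the mean subtraction is the
    # step where the traceback shows the compile-time abort.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    centered = x - mean
    var = jnp.mean(centered ** 2, axis=-1, keepdims=True)
    return centered * jax.lax.rsqrt(var + eps)


def block(x, w):
    # Hypothetical block: normalize, then project.
    return layer_norm(x) @ w


# Rematerialize the block in the backward pass, as gradient checkpointing does.
checkpointed_block = jax.checkpoint(block)


def loss(w, x):
    return jnp.sum(checkpointed_block(x, w) ** 2)


grad_fn = jax.jit(jax.grad(loss))

if __name__ == "__main__":
    kx, kw = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(kx, (4, 16))
    w = jax.random.normal(kw, (16, 16))
    # Eager call (compiles small ops one at a time), then the jitted gradient.
    print(layer_norm(x).shape)
    print(grad_fn(w, x).shape)
```

If this script also aborts on the same hardware and JAX build, the bug is most likely in JAX/XLA rather than in Levanter.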

DwarKapex commented 1 week ago

This can be closed.

DwarKapex commented 1 week ago

It was indeed a problem on the JAX side.