In JAX-Toolbox we see the following error for both V100 and A100:
opt/levanter/tests/gpt2_test.py Fatal Python error: Aborted
Thread 0x00007f02ebfff640 (most recent call first):
File "/usr/lib/python3.10/threading.py", line 324 in wait
File "/usr/lib/python3.10/threading.py", line 607 in wait
File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap
Current thread 0x00007f0c4cd781c0 (most recent call first):
File "/opt/jax/jax/_src/compiler.py", line 266 in backend_compile
File "/opt/jax/jax/_src/profiler.py", line 333 in wrapper
File "/opt/jax/jax/_src/compiler.py", line 654 in _compile_and_write_cache
File "/opt/jax/jax/_src/compiler.py", line 426 in compile_or_get_cached
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2670 in _cached_compilation
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2857 in from_hlo
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2344 in compile
File "/opt/jax/jax/_src/pjit.py", line 1674 in _pjit_call_impl_python
File "/opt/jax/jax/_src/pjit.py", line 1744 in call_impl_cache_miss
File "/opt/jax/jax/_src/pjit.py", line 1768 in _pjit_call_impl
File "/opt/jax/jax/_src/core.py", line 939 in process_primitive
File "/opt/jax/jax/_src/core.py", line 433 in bind_with_trace
File "/opt/jax/jax/_src/core.py", line 2761 in bind
File "/opt/jax/jax/_src/pjit.py", line 190 in _python_pjit_helper
File "/opt/jax/jax/_src/pjit.py", line 356 in cache_miss
File "/opt/jax/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
File "/opt/haliax/src/haliax/__init__.py", line 875 in subtract
File "/opt/haliax/src/haliax/wrap.py", line 97 in binop
File "/opt/haliax/src/haliax/core.py", line 577 in __sub__
File "/opt/haliax/src/haliax/nn/normalization.py", line 47 in __call__
File "/opt/levanter/src/levanter/models/gpt2.py", line 299 in __call__
File "/usr/lib/python3.10/contextlib.py", line 79 in inner
File "/opt/levanter/src/levanter/models/gpt2.py", line 399 in __call__ (CI log: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/10495491812/job/29074977755?pr=1009#step:7:400)
File "/opt/levanter/tests/gpt2_test.py", line 41 in test_gradient_checkpointing
File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 1627 in runtest
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 174 in pytest_runtest_call
File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 242 in <lambda>
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 341 in from_call
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 241 in call_and_report
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 132 in runtestprotocol
File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 362 in pytest_runtestloop
File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 337 in _main
File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 283 in wrap_session
File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 330 in pytest_cmdline_main
File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 175 in main
File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 201 in console_main
File "/usr/local/bin/pytest", line 8 in <module>
Hi all.
In JAX-Toolbox we see the error shown above on both V100 and A100 GPUs.
Could you please take a look? Thanks in advance!