jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax/tests/lax_scipy_sparse_test.py segfaults on GPU; other GPU test failures #5713

Closed by skye 3 years ago

skye commented 3 years ago

I can't get the full unit test suite to run to completion with jaxlib==0.1.60+cuda111: the run segfaults partway through. I suspect this affects all GPU builds.

$ python3 -m pytest jax/tests/
================================================================================================================================== test session starts ===================================================================================================================================
platform linux -- Python 3.6.9, pytest-6.1.1, py-1.9.0, pluggy-0.13.1
rootdir: /home/skyewm/jax, configfile: pytest.ini
plugins: xdist-2.1.0, forked-1.3.0
collected 11931 items                                                                                                                                                                                                                                                                    

jax/tests/api_test.py ....................s.................ss.......................................................................................................................................s..............................s.......ss..........s................s........ [  2%]
...............................s......s................................s.ss.......s..........s...................                                                                                                                                                                  [  3%]
jax/tests/api_util_test.py ............                                                                                                                                                                                                                                            [  3%]
jax/tests/array_interoperability_test.py ..........sssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                                             [  3%]
jax/tests/batching_test.py .......................................ssssssssssss...................................FF................................................................FF....................F...                                                                      [  5%]
jax/tests/callback_test.py .........                                                                                                                                                                                                                                               [  5%]
jax/tests/core_test.py ......................................................................................................................................................................................................................................                      [  7%]
jax/tests/custom_object_test.py ................                                                                                                                                                                                                                                   [  7%]
jax/tests/debug_nans_test.py ..........                                                                                                                                                                                                                                            [  7%]
jax/tests/doubledouble_test.py ..............................................................                                                                                                                                                                                      [  7%]
jax/tests/dtypes_test.py ......................................................................................................................................................................................................................................................... [  9%]
................................................................................................................................                                                                                                                                                   [ 11%]
jax/tests/errors_test.py ssss                                                                                                                                                                                                                                                      [ 11%]
jax/tests/fft_test.py ..........................................................................................                                                                                                                                                                   [ 11%]
jax/tests/generated_fun_test.py ........................                                                                                                                                                                                                                           [ 12%]
jax/tests/host_callback_to_tf_test.py ssssssssss                                                                                                                                                                                                                                   [ 12%]
jax/tests/image_test.py ssssssssssssssssssss..........................................                                                                                                                                                                                             [ 12%]
jax/tests/infeed_test.py ....                                                                                                                                                                                                                                                      [ 12%]
jax/tests/jax_jit_test.py ..............                                                                                                                                                                                                                                           [ 12%]
jax/tests/jax_to_hlo_test.py ..                                                                                                                                                                                                                                                    [ 12%]
jax/tests/jaxpr_util_test.py .....                                                                                                                                                                                                                                                 [ 12%]
jax/tests/jet_test.py ......s.......F........s..........ss...ss..s...s.ss..........sss.......s                                                                                                                                                                                     [ 13%]
jax/tests/lax_autodiff_test.py ..................................F....F.FFFFFFFFFFFFFFFFFFFFF..................................................................................................................................................................................... [ 15%]
.................................................................................................................................................................................................................................................................................. [ 17%]
............................                                                                                                                                                                                                                                                       [ 18%]
jax/tests/lax_control_flow_test.py .....................s......................................................................................................................................................................................................................... [ 20%]
...............................                                                                                                                                                                                                                                                    [ 20%]
jax/tests/lax_numpy_einsum_test.py ..............................................................................................................................................                                                                                                  [ 21%]
jax/tests/lax_numpy_indexing_test.py ............................................................................................................................................................................................................................................. [ 23%]
.................................................................................                                                                                                                                                                                                  [ 24%]
jax/tests/lax_numpy_test.py ..................................................................................s...............................................s.s................................................................................................................. [ 26%]
...................................................FFFFFFFFFF..............................................................................................................................................................................................................sssss.. [ 28%]
.......................................................................................................................................................................................s.........sssssss.......................................................................... [ 30%]
.................................................................................................................................................................................................................................................................................. [ 33%]
.................................................................................................................................................................................................................................................................................. [ 35%]
.................................................................................................................................................................................................................................................................................. [ 37%]
.................................................................................................................................................................................................................................................................................. [ 39%]
.................................................................................................................................................................................................................................................................................. [ 42%]
.......................ss..............................................................FFFFFFFFFF................................................................................................................................................................................. [ 44%]
.................................................................................................................................................................................................................................................................................. [ 46%]
..........................sssss......................................................................................................................................s............................................................................................................ [ 49%]
.......................................................................................................................................................s..ss..s..s..ss..s.s.sss.ss.s.s..ss..s..s..ss..s..s..ss..s..s..ss..s..s.s..s....s..ss..s..ss.ss..s..s..s........s...s..s..s [ 51%]
......s..s........s.s..s....s.s..s..s..ss..s..s..ss..s..s.s..s....s..ss..s..s..ss..s....s.s..s....s.s..s.s.sss.ss.s.s.s..ss..s.sss.ss.s.s..ss..s..s..ss..s..s..ss..s....s.s..s.s.sss.ss.s.s.sss.ss.s..s.s..s.s.sss.ss.ss.sss.ss.s.s.s..ss.....s.s..s.ss....s..ss.s.sss.ss.ssss.ss. [ 53%]
s.s.s..s....s.s..s......s...s..s..ss..s..s..ss..s..s..ss..s.......s..s..s..ss.s.ss..ss....s.s.sss.ss.ssss.ss.s.s..ss..s..s..ss..s..s..ss..s..s..ss..s..s.s..s....s.s..s...s.sss.ss.s.s..ss..s.s.sss.ss.ss.sss.ss.ss.sss.ss.ss.sss.ss.ss.sss.ss.s.s.s..ss.....s.s..s.s.sss.ss.ss... [ 56%]
.s.....s.s..s...s.sss.ss.ss....s....s.sss.ss.s...s.s..s....s.s..s..s..ss..s..s.s..ss..ss..ss....ss..ss..s.s....s.......s.s..s..s..ss..s..s..ss..s.ssssssssss.s..ss..s.ss..ss..s.ssssssss...s..ss..s..s..ss..s.s.sss.ss.s...s.s..s.                                                 [ 57%]
jax/tests/lax_numpy_vectorize_test.py ............................                                                                                                                                                                                                                 [ 58%]
jax/tests/lax_scipy_sparse_test.py ssssssssss......ssssssssssssssssss.......Fatal Python error: Segmentation fault

Thread 0x00007fb2a68ec740 (most recent call first):
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 356 in backend_compile
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 292 in xla_primitive_callable
  File "/home/skyewm/jax/jax/_src/util.py", line 191 in cached
  File "/home/skyewm/jax/jax/_src/util.py", line 198 in wrapper
  File "/home/skyewm/jax/jax/interpreters/xla.py", line 242 in apply_primitive
  File "/home/skyewm/jax/jax/core.py", line 628 in process_primitive
  File "/home/skyewm/jax/jax/core.py", line 282 in bind
  File "/home/skyewm/jax/jax/core.py", line 363 in eval_jaxpr
  File "/home/skyewm/jax/jax/core.py", line 152 in jaxpr_as_fun
  File "/home/skyewm/jax/jax/_src/lax/control_flow.py", line 2234 in _custom_linear_solve_impl
  File "/home/skyewm/jax/jax/core.py", line 628 in process_primitive
  File "/home/skyewm/jax/jax/core.py", line 282 in bind
  File "/home/skyewm/jax/jax/_src/lax/control_flow.py", line 2224 in custom_linear_solve
  File "/home/skyewm/jax/jax/_src/scipy/sparse/linalg.py", line 622 in gmres
  File "/home/skyewm/jax/tests/lax_scipy_sparse_test.py", line 271 in test_gmres_on_identity_system
[...]

Looks like there are other test failures too, but they didn't print due to the segfault. cc @hawkinsp

hawkinsp commented 3 years ago

I built jaxlib with debug symbols and grabbed a stack trace:

#0  raise (sig=<optimized out>) at ../sysdeps/unix/sysv/linux/raise.c:51
#1  <signal handler called>
#2  0x00007fd372341bd8 in __GI___pthread_timedjoin_ex (threadid=140545868896000, thread_return=0x0,
    abstime=0x0, block=true) at pthread_join_common.c:40
#3  0x00007fd36c7339bf in blas_thread_shutdown_ ()
   from /home/phawkins/.pyenv/versions/py3.8.6/lib/python3.8/site-packages/numpy/core/../../numpy.libs/libopenblasp-r0-5bebc122.3.13.dev.so
#4  0x00007fd37188789a in __libc_fork () at ../sysdeps/nptl/fork.c:96
#5  0x00007fd345661070 in tensorflow::SubProcess::Start (this=0x7fd1fdffd4a0)
    at external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:210
#6  0x00007fd345501393 in stream_executor::CompileGpuAsm (cc_major=7, cc_minor=0,
    ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
    at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:249
#7  0x00007fd345500163 in stream_executor::CompileGpuAsm (device_ordinal=0,
    ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
    at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:150
#8  0x00007fd33fb124e2 in xla::gpu::NVPTXCompiler::CompileGpuAsmOrGetCachedResult (
    this=0x563d97c2cba0, stream_exec=0x7fd218006a60,
    ptx="//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., cc_major=7, cc_minor=0, hlo_module_config=..., relocatable=true)
    at external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:377
#9  0x00007fd33fb11dc4 in xla::gpu::NVPTXCompiler::CompileTargetBinary (this=0x563d97c2cba0,
    module_config=..., llvm_module=0x7fd1c8020050, gpu_version=..., stream_exec=0x7fd218006a60,
    relocatable=true, debug_module=0x563dc0e82fa0)

This is probably an instance of OpenBLAS (used by NumPy) misbehaving in a process that also calls fork(): frames #3 and #4 show OpenBLAS's blas_thread_shutdown_ running from inside __libc_fork.
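
To illustrate the mechanism, here is a minimal sketch of how fork handlers behave, using Python's own `os.register_at_fork` as a stand-in. This is only an analogy: OpenBLAS appears to hook fork at the C level (the trace shows `blas_thread_shutdown_` called from `__libc_fork`), whereas the hooks below are Python-level. The names `calls` and `fork_once` are made up for the example, and it assumes a POSIX system with Python 3.7 or newer.

```python
import os

calls = []  # records which hooks ran in *this* process

# Python-level analogue of pthread_atfork(): hooks that run around os.fork().
os.register_at_fork(
    before=lambda: calls.append("before"),                    # parent, just before fork
    after_in_parent=lambda: calls.append("after_in_parent"),  # parent, just after fork
    after_in_child=lambda: calls.append("after_in_child"),    # child only
)


def fork_once():
    """Fork a child that exits immediately, then reap it."""
    pid = os.fork()
    if pid == 0:
        os._exit(0)  # child: exit without running Python cleanup
    os.waitpid(pid, 0)


fork_once()
print(calls)  # the child's append happens in the child's own copy of the list
```

The point is that a library's "before" hook runs inside every fork() made anywhere in the process, including forks the library knows nothing about, such as the one XLA makes here to launch a subprocess.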

hawkinsp commented 3 years ago

https://github.com/xianyi/OpenBLAS/pull/3111 should fix the underlying OpenBLAS problem, I think. I can't confirm that 100% because I was unable to reproduce the original issue with a self-built OpenBLAS, only the one that is bundled with NumPy.

However, since it will take some time for any OpenBLAS fix to make it into a NumPy release and for that fix to make it to users, I'll also look into avoiding calling pthread_atfork handlers when spawning a subprocess.
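
As a rough sketch of that idea in Python (not the actual TensorFlow change, which is C++ and not shown in this thread), a child process can be launched with `posix_spawn` instead of an explicit fork(). POSIX leaves it implementation-defined whether fork handlers run for `posix_spawn`, so whether this avoids them on a given platform is an assumption to verify. `spawn_and_wait` is a hypothetical helper name; `os.posix_spawn` requires Python 3.8 or newer.

```python
import os
import sys


def spawn_and_wait(argv):
    """Launch argv[0] via posix_spawn (no explicit os.fork()) and return its exit code."""
    pid = os.posix_spawn(argv[0], argv, dict(os.environ))
    _, status = os.waitpid(pid, 0)
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    return -os.WTERMSIG(status)  # negative signal number if killed by a signal


# Example: run a tiny Python child that exits with status 3.
exit_code = spawn_and_wait([sys.executable, "-c", "raise SystemExit(3)"])
print(exit_code)
```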

hawkinsp commented 3 years ago

With an upcoming fix to TensorFlow to avoid calling pthread_atfork() handlers, I am down to only two failures:

=========================== short test summary info ============================
FAILED tests/lax_numpy_test.py::NumpySignaturesTest::testWrappedSignaturesMatch
FAILED tests/pmap_test.py::PmapTest::test_replicate_backend - ValueError: com...
========== 2 failed, 10854 passed, 1198 skipped in 955.43s (0:15:55) ===========

The former is related to NumPy 1.20 on my machine and unrelated to GPU specifically.

I am unsure about the latter: it doesn't appear when I run that one file in isolation, so I'm guessing it has something to do with pytest running a particular combination of tests on one worker.

hawkinsp commented 3 years ago

The pmap_test.py failure looks like this:

self = <pmap_test.PmapTest testMethod=test_replicate_backend>

    @jtu.skip_on_devices("cpu")
    def test_replicate_backend(self):
      # https://github.com/google/jax/issues/4223
      def fn(indices):
        return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
      mapped_fn = jax.pmap(fn, axis_name='i', backend='cpu')
      mapped_fn = jax.pmap(mapped_fn, axis_name='j', backend='cpu')
      indices = np.array([[[2], [1]], [[0], [0]]])
>     mapped_fn(indices)  # doesn't crash
E     jax._src.traceback_util.FilteredStackTrace: ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)
E
E     The stack trace above excludes JAX-internal frames.
E     The following is the original exception that occurred, unmodified.
E
E     --------------------

tests/pmap_test.py:1641: FilteredStackTrace
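
For reference, the "4 logical devices" in the error follows from the nesting: each pmap maps over one leading axis of `indices`, and the replica counts multiply. A small stdlib-only sketch of that arithmetic (the helper `required_replicas` is hypothetical and not part of JAX; `math.prod` requires Python 3.8 or newer):

```python
import math


def required_replicas(shape, num_nested_pmaps):
    """Each nested pmap consumes one leading axis; the replica counts multiply."""
    if num_nested_pmaps > len(shape):
        raise ValueError("more nested pmaps than array axes")
    return math.prod(shape[:num_nested_pmaps])


indices_shape = (2, 2, 1)  # shape of np.array([[[2], [1]], [[0], [0]]]) in the test
needed = required_replicas(indices_shape, num_nested_pmaps=2)
available = 1  # "only 1 XLA devices are available" in the error message

print(needed)               # matches num_replicas=4 in the error
print(needed <= available)  # False: the computation cannot be placed
```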

hawkinsp commented 3 years ago

I think all the issues identified here are already fixed at head.