I built jaxlib with debug symbols and grabbed a stack trace:
#0 raise (sig=<optimized out>) at ../sysdeps/unix/sysv/linux/raise.c:51
#1 <signal handler called>
#2 0x00007fd372341bd8 in __GI___pthread_timedjoin_ex (threadid=140545868896000, thread_return=0x0,
abstime=0x0, block=true) at pthread_join_common.c:40
#3 0x00007fd36c7339bf in blas_thread_shutdown_ ()
from /home/phawkins/.pyenv/versions/py3.8.6/lib/python3.8/site-packages/numpy/core/../../numpy.libs/libopenblasp-r0-5bebc122.3.13.dev.so
#4 0x00007fd37188789a in __libc_fork () at ../sysdeps/nptl/fork.c:96
#5 0x00007fd345661070 in tensorflow::SubProcess::Start (this=0x7fd1fdffd4a0)
at external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:210
#6 0x00007fd345501393 in stream_executor::CompileGpuAsm (cc_major=7, cc_minor=0,
ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:249
#7 0x00007fd345500163 in stream_executor::CompileGpuAsm (device_ordinal=0,
ptx_contents=0x7fd1c809c670 "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., options=...)
at external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:150
#8 0x00007fd33fb124e2 in xla::gpu::NVPTXCompiler::CompileGpuAsmOrGetCachedResult (
this=0x563d97c2cba0, stream_exec=0x7fd218006a60,
ptx="//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_70\n.address_size 64\n\n\t// .globl\tadd_16\n.extern .global .align 64 .b8 buffer_for_constant_219[4];\n\n.visible .entry add_16(\n\t.param .u6"..., cc_major=7, cc_minor=0, hlo_module_config=..., relocatable=true)
at external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:377
#9 0x00007fd33fb11dc4 in xla::gpu::NVPTXCompiler::CompileTargetBinary (this=0x563d97c2cba0,
module_config=..., llvm_module=0x7fd1c8020050, gpu_version=..., stream_exec=0x7fd218006a60,
relocatable=true, debug_module=0x563dc0e82fa0)
This is probably an instance of OpenBLAS (used by NumPy) misbehaving in a process that also fork()s.
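As an illustration of the hazard (a minimal sketch of the suspected mechanism, not a reproducer I have verified against this issue):

```python
import os
import numpy as np  # assumes a NumPy wheel that bundles OpenBLAS

# Warm up OpenBLAS's worker thread pool with a parallel matmul.
a = np.ones((2048, 2048))
_ = a @ a

# fork() runs OpenBLAS's pthread_atfork handler (blas_thread_shutdown_ in
# frame #3 above), which joins the worker threads in the parent; if a worker
# is in a bad state, that pthread join can hang or crash, as in frame #2.
pid = os.fork()
if pid == 0:
    os._exit(0)  # the child does nothing; the problem is in the parent's atfork path
os.waitpid(pid, 0)
```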
https://github.com/xianyi/OpenBLAS/pull/3111 should fix the underlying OpenBLAS problem, I think. I can't confirm that 100%, because I was unable to reproduce the original issue with a self-built OpenBLAS, only with the one bundled with NumPy.
However, since it will take some time for any OpenBLAS fix to make it into a NumPy release and for that fix to reach users, I'll also look into avoiding calling pthread_atfork handlers when spawning a subprocess.
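One way to do that (a sketch of the general technique, not necessarily the exact TensorFlow change) is to spawn with posix_spawn, which on glibc does not run pthread_atfork handlers:

```python
import os
import sys

# posix_spawn creates the child without running the parent's pthread_atfork
# handlers (on glibc), so OpenBLAS's fragile handler is never invoked.
pid = os.posix_spawn(
    sys.executable,                               # path to the executable
    [sys.executable, "-c", "print('child ok')"],  # argv
    dict(os.environ),                             # environment
)
os.waitpid(pid, 0)
```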
With an upcoming fix to TensorFlow to avoid calling pthread_atfork() handlers, I am down to only two failures:
=========================== short test summary info ============================
FAILED tests/lax_numpy_test.py::NumpySignaturesTest::testWrappedSignaturesMatch
FAILED tests/pmap_test.py::PmapTest::test_replicate_backend - ValueError: com...
========== 2 failed, 10854 passed, 1198 skipped in 955.43s (0:15:55) ===========
The former is related to NumPy 1.20 on my machine and unrelated to GPU specifically.
I'm unsure about the latter: it doesn't appear when I run that one file in isolation, so I'm guessing it must have something to do with pytest running a particular combination of tests on one worker.
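One way to check that hypothesis (a sketch; it assumes the suite is parallelized with pytest-xdist's -n flag, which is my assumption about this setup):

```python
import subprocess
import sys

# The file on its own: test_replicate_backend passes here.
subprocess.run([sys.executable, "-m", "pytest", "tests/pmap_test.py"])

# The full suite on parallel workers: the failure shows up only here,
# presumably from some combination of tests scheduled onto one worker.
subprocess.run([sys.executable, "-m", "pytest", "-n", "auto", "tests/"])
```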
The pmap_test.py failure looks like this:
self = <pmap_test.PmapTest testMethod=test_replicate_backend>

    @jtu.skip_on_devices("cpu")
    def test_replicate_backend(self):
      # https://github.com/google/jax/issues/4223
      def fn(indices):
        return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
      mapped_fn = jax.pmap(fn, axis_name='i', backend='cpu')
      mapped_fn = jax.pmap(mapped_fn, axis_name='j', backend='cpu')
      indices = np.array([[[2], [1]], [[0], [0]]])
>     mapped_fn(indices)  # doesn't crash
E     jax._src.traceback_util.FilteredStackTrace: ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)
E
E     The stack trace above excludes JAX-internal frames.
E     The following is the original exception that occurred, unmodified.
E
E     --------------------

tests/pmap_test.py:1641: FilteredStackTrace
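For context on the error message: the nested pmap needs 2 × 2 = 4 replicas, but JAX exposes only a single CPU device unless more are forced. A sketch of how to provide them (the XLA flag is real; that the test harness relies on it is my assumption):

```python
import os

# Must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np

print(jax.device_count("cpu"))  # 8 with the flag above; 1 without it

def fn(indices):
    return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)

# Outer axis of size 2 times inner axis of size 2 = the 4 logical devices
# demanded by the ValueError above.
mapped_fn = jax.pmap(jax.pmap(fn, axis_name='i', backend='cpu'),
                     axis_name='j', backend='cpu')
indices = np.array([[[2], [1]], [[0], [0]]])
print(mapped_fn(indices).shape)  # (2, 2, 3)
```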
I think all the issues identified here are already fixed at head.
I'm unable to run all unit tests with jaxlib==0.1.60+cuda111. I suspect this is an issue for all GPU builds. It looks like there are other test failures too, but they didn't print due to the segfault. cc @hawkinsp