jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

RuntimeError: INTERNAL: Core halted unexpectedly: No error message available as no compiler metadata was provided. #9642

Open ayaka14732 opened 2 years ago

ayaka14732 commented 2 years ago

This script used to run normally on a Cloud TPU v2-8 VM, but now it shows an error:

import os
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import jax
import subprocess
np = jax.numpy

devices = jax.devices()

def show_mem(result: np.ndarray) -> str:
    result.block_until_ready()
    jax.profiler.save_device_memory_profile('/tmp/memory.prof')
    return subprocess.run(['go', 'tool', 'pprof', '-tags', '/tmp/memory.prof'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8')

def largest_v2() -> np.ndarray:
    return np.zeros((1024, 1024, 957, 2), dtype=np.float32)

# print(show_mem(largest_v2()))

print(show_mem(jax.jit(largest_v2, device=devices[1])()))
print(show_mem(jax.jit(largest_v2, device=devices[2])()))

Error message:

$ python test_memory.py
 device: Total 7.5GB
         7.5GB (  100%): TPU_1(process=0,(0,0,0,1))

 kind: Total 7.5GB
         7.5GB (  100%): buffer
       -1.0B (1.2e-08%): executable

2022-02-19 21:29:36.266338: W external/org_tensorflow/tensorflow/stream_executor/stream.cc:275] Error blocking host until done in stream destructor: INTERNAL: stream did not block host until done; was already in an error state
Traceback (most recent call last):
  File "test_memory.py", line 22, in <module>
    print(show_mem(jax.jit(largest_v2, device=devices[2])()))
  File "test_memory.py", line 12, in show_mem
    result.block_until_ready()
RuntimeError: INTERNAL: Core halted unexpectedly: No error message available as no compiler metadata was provided.
2022-02-19 21:29:36.404625: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/local_device_state.cc:74] Error when closing device: INTERNAL: Core halted unexpectedly: No error message available as no compiler metadata was provided.
2022-02-19 21:29:36.404907: W external/org_tensorflow/tensorflow/stream_executor/stream.cc:275] Error blocking host until done in stream destructor: INTERNAL: stream did not block host until done; was already in an error state
2022-02-19 21:29:36.405494: W external/org_tensorflow/tensorflow/stream_executor/stream.cc:275] Error blocking host until done in stream destructor: INTERNAL: stream did not block host until done; was already in an error state

Library versions:

$ pip list | grep jax
jax                      0.3.1
jaxlib                   0.3.0
young-geng commented 2 years ago

I'm also getting this error on a TPU v3-8 VM. However, it seems that the error is not triggered if I don't transfer the computation results to the CPU.
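
To clarify what I mean by transferring the results to the CPU (a minimal sketch, not my actual code; the error only appears for me once I pull a device array back to the host):

import jax
import jax.numpy as jnp

# The result of a jitted computation stays resident on the TPU device.
x = jax.jit(lambda: jnp.zeros((1024, 1024)))()

# The explicit device-to-host copy is the step that triggers the error for me.
x_host = jax.device_get(x)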

skye commented 2 years ago

Hm, I'm not able to repro this:

$ python3 test_memory.py 
 device: Total 7.5GB
         7.5GB (  100%): TPU_1(process=0,(0,0,0,1))

 kind: Total 7.5GB
         7.5GB (  100%): buffer
       -1.0B (1.2e-08%): executable

 device: Total 7.5GB
         7.5GB (  100%): TPU_2(process=0,(1,0,0,0))

 kind: Total 7.5GB
         7.5GB (  100%): buffer
       -2.0B (2.5e-08%): executable

This is running on a fresh v2-8 (no virtualenv or anything, just jax 0.3.1 and jaxlib 0.3.0). Can you confirm you have this libtpu version?

$ pip list | grep libtpu
libtpu-nightly           0.1.dev20220128

Can you also rerun your repro and send me /tmp/tpu_logs/tpu_driver.INFO? (You can stick it somewhere visible from this issue or email me directly)

young-geng commented 2 years ago

I'm using different code but encountered the same error message. Here are my JAX and libtpu versions:

$ pip list | grep libtpu
libtpu-nightly                    0.1.dev20220128

$ pip list | grep jax
jax                               0.3.1
jaxlib                            0.3.0

I've attached my tpu_driver.INFO in this gist.

ayaka14732 commented 2 years ago

I am using virtualenv:

jax                          0.3.1
jaxlib                       0.3.0
libtpu-nightly               0.1.dev20220128

tpu_driver.INFO (Ubuntu Pastebin, expires in one week)

skye commented 2 years ago

Thanks for this info. I've filed an internal issue with the TPU runtime team (b/222577401).

Can you try with the following versions and let me know if you still get the same error? This will help pin down where the issue is. I suggest running pip uninstall -y jax jaxlib libtpu-nightly between installs to make sure the right versions are then installed.

  1. pip install jax==0.3.1 jaxlib==0.3.0 libtpu-nightly==0.1.dev20220218 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  2. pip install jax[tpu]==0.2.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (previous jax version)
  3. pip install jax[tpu]==0.2.21 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (arbitrary older version)

@young-geng are you able to provide your code? If not, are you also spawning a subprocess?

@ayaka14732 do you still get the error if you remove the subprocess call? (e.g. just return "test" or something)
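
For example, something like this, keeping the rest of your repro unchanged (just a sketch of what I mean):

def show_mem(result: np.ndarray) -> str:
    result.block_until_ready()
    jax.profiler.save_device_memory_profile('/tmp/memory.prof')
    return 'test'  # no subprocess call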

Thanks again for your help debugging this! Hopefully we can get to the bottom of this soon. Feel free to provide any information as you have it.

ayaka14732 commented 2 years ago

Oh, you've mentioned libtpu-nightly==0.1.dev20220218. I just wanted to report that there is another issue with this version, but I don't know where to report it.

This version does not work at all (both for JAX and PyTorch XLA):

$ python test.py
https://symbolize.stripped_domain/r/?trace=7fcf5129603b,7fcf512960bf,7fcf4405434b,7fcf44055176,7fcf42732dbb,7fcf4db93c22,7fcf4bb4a196,7fcf4bb400c7,7fcf4bb37305,7fcf4b104e6a,7fcf4b113705,7fcf49573ba1,7fcf492ba795,7fcf492ba9db,7fcf492972f9,5f5e78,903aff&map=2a1715f2c413d37900cabb303a7e0c76:7fcf48d5c000-7fcf4f8378e1,8de474c84c849036fc788bcde7d9ce73:7fcf39a9b000-7fcf4852e680
*** SIGABRT received by PID 545434 (TID 545434) on cpu 25 from PID 545434; stack trace: ***
PC: @     0x7fcf5129603b  (unknown)  raise
    @     0x7fcf481c147a        992  (unknown)
    @     0x7fcf512960c0  262631056  (unknown)
    @     0x7fcf4405434c        128  (unknown)
    @     0x7fcf44055177        192  (unknown)
    @     0x7fcf42732dbc       1328  TpuCompiler_RunHloPasses
    @     0x7fcf4db93c23       4160  xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
    @     0x7fcf4bb4a197        544  xla::Service::BuildExecutable()
    @     0x7fcf4bb400c8       1056  xla::LocalService::CompileExecutables()
    @     0x7fcf4bb37306       2432  xla::LocalClient::Compile()
    @     0x7fcf4b104e6b        832  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fcf4b113706       1088  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fcf49573ba2       1040  xla::PyClient::CompileMlir()
    @     0x7fcf492ba796       1696  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7fcf492ba9dc        208  pybind11::cpp_function::initialize<>()::{lambda()#3}::operator()()
    @     0x7fcf492972fa        640  pybind11::cpp_function::dispatcher()
    @           0x5f5e79  (unknown)  PyCFunction_Call
    @           0x903b00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fcf5129603b,7fcf481c1479,7fcf512960bf,7fcf4405434b,7fcf44055176,7fcf42732dbb,7fcf4db93c22,7fcf4bb4a196,7fcf4bb400c7,7fcf4bb37305,7fcf4b104e6a,7fcf4b113705,7fcf49573ba1,7fcf492ba795,7fcf492ba9db,7fcf492972f9,5f5e78,903aff&map=2a1715f2c413d37900cabb303a7e0c76:7fcf48d5c000-7fcf4f8378e1,8de474c84c849036fc788bcde7d9ce73:7fcf39a9b000-7fcf4852e680
E0304 08:36:11.810099  545434 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0304 08:36:11.810108  545434 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0304 08:36:11.810123  545434 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0304 08:36:11.810127  545434 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0304 08:36:11.810138  545434 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0304 08:36:11.810145  545434 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0304 08:36:11.810149  545434 coredump_hook.cc:550] RAW: Discarding core.
E0304 08:36:12.046387  545434 process_state.cc:771] RAW: Raising signal 6 with default behavior
[1]    545434 abort (core dumped)  python test.py
ayaka14732 commented 2 years ago

If I remove the subprocess call, I don't get the error; it works well.

skye commented 2 years ago

Thanks for the speedy replies!

For the libtpu-nightly==0.1.dev20220218 SIGABRT failure, please feel free to report that kind of thing here! In this case, we're already aware of the issue and should have a fixed libtpu-nightly out soon (apologies for suggesting you try it, I forgot about this issue).

Thanks also for isolating where the Core halted unexpectedly error began. This will help with debugging.

young-geng commented 2 years ago

Unfortunately, I'm not able to share my code at the moment, but I can confirm that it also fails with JAX 0.2.28.

Regarding subprocesses: indeed, I was using multiprocessing to spawn child processes for data preprocessing on the CPU. I found a workaround by not importing JAX or doing any JAX operations until I have spawned all the child processes I need; with that ordering, the same JAX training code works without a problem. It would be great to get this resolved, since multiprocess data preprocessing is common practice when training large-scale models.
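
In case the ordering is useful to anyone else, here is a minimal sketch of the workaround (not my actual training code; the worker logic is just a placeholder):

import multiprocessing as mp

def preprocess_worker(in_q, out_q):
    # CPU-only data preprocessing; the child processes never import or touch JAX.
    while True:
        item = in_q.get()
        if item is None:
            break
        out_q.put(item * 2)

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    in_q, out_q = ctx.Queue(), ctx.Queue()
    workers = [ctx.Process(target=preprocess_worker, args=(in_q, out_q)) for _ in range(2)]
    for w in workers:
        w.start()

    # Only import JAX and touch the TPU after all child processes have been spawned.
    import jax
    import jax.numpy as jnp

    for i in range(4):
        in_q.put(i)
    batch = jnp.asarray(sorted(out_q.get() for _ in range(4)))
    print(jax.jit(lambda x: x + 1)(batch))

    # Shut the workers down cleanly.
    for _ in workers:
        in_q.put(None)
    for w in workers:
        w.join()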

Thanks for helping me with this problem!

froody commented 2 years ago

I'm seeing this with TF 2.8 on a TPU v2-8.

theyorubayesian commented 2 months ago

I get this issue with libtpu 0.1.dev20240617 and jax==0.4.30 when this section of my code runs on a TPU v4-32:

epochs = tqdm(args, kwargs)

...
...
...

train_metric = jax_utils.unreplicate(train_metric)
epochs.write(f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']} | Learning rate: {train_metric['learning_rate']})")