ayaka14732 opened this issue 2 years ago
I'm also getting this error on a TPU v3-8 VM. However, it seems that the error is not triggered if I don't transfer the computation results to the CPU.
Hm, I'm not able to repro this:
$ python3 test_memory.py
device: Total 7.5GB
7.5GB ( 100%): TPU_1(process=0,(0,0,0,1))
kind: Total 7.5GB
7.5GB ( 100%): buffer
-1.0B (1.2e-08%): executable
device: Total 7.5GB
7.5GB ( 100%): TPU_2(process=0,(1,0,0,0))
kind: Total 7.5GB
7.5GB ( 100%): buffer
-2.0B (2.5e-08%): executable
This is running on a fresh v2-8 (no virtualenv or anything, just jax 0.3.1 and jaxlib 0.3.0). Can you confirm you have this libtpu version?
$ pip list | grep libtpu
libtpu-nightly 0.1.dev20220128
Can you also rerun your repro and send me /tmp/tpu_logs/tpu_driver.INFO? (You can stick it somewhere visible from this issue or email me directly.)
I'm using different code but encountered the same error message. Here are my JAX and libtpu versions:
$ pip list | grep libtpu
libtpu-nightly 0.1.dev20220128
$ pip list | grep jax
jax 0.3.1
jaxlib 0.3.0
I've attached my tpu_driver.INFO in this gist.
I am using a virtualenv with:
jax 0.3.1
jaxlib 0.3.0
libtpu-nightly 0.1.dev20220128
tpu_driver.INFO (Ubuntu Pastebin, expires in one week)
Thanks for this info. I've filed an internal issue with the TPU runtime team (b/222577401).
Can you try with the following versions and let me know if you still get the same error? This will help pin down where the issue is. I suggest running pip uninstall -y jax jaxlib libtpu-nightly between installs to make sure the right versions are then installed.
pip install jax==0.3.1 jaxlib==0.3.0 libtpu-nightly==0.1.dev20220218 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install jax[tpu]==0.2.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (previous jax version)
pip install jax[tpu]==0.2.21 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (arbitrary older version)
@young-geng are you able to provide your code? If not, are you also spawning a subprocess?
@ayaka14732 do you still get the error if you remove the subprocess call? (e.g. just return "test" or something)
Thanks again for your help debugging this! Hopefully we can get to the bottom of this soon. Feel free to provide any information as you have it.
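To make that control experiment concrete, here is a minimal stdlib-only sketch of the two variants being compared (the function names are hypothetical; in the real repro, JAX/TPU work would already have happened before this point):

```python
import subprocess

def run_with_subprocess():
    # The suspect pattern: the training process spawns a child process
    # after the TPU runtime has already been initialized.
    result = subprocess.run(["echo", "test"], capture_output=True, text=True)
    return result.stdout.strip()

def run_without_subprocess():
    # The control case suggested above: same call site, but just return
    # a constant instead of spawning a child process.
    return "test"

print(run_with_subprocess(), run_without_subprocess())
```

If the error disappears in the second variant, that points at an interaction between process spawning and the initialized TPU runtime rather than at the computation itself.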
Oh, you mentioned libtpu-nightly==0.1.dev20220218. I just wanted to report that there is another issue with that version, but I don't know where to report it. That version does not work at all (for both JAX and PyTorch XLA):
$ python test.py
https://symbolize.stripped_domain/r/?trace=7fcf5129603b,7fcf512960bf,7fcf4405434b,7fcf44055176,7fcf42732dbb,7fcf4db93c22,7fcf4bb4a196,7fcf4bb400c7,7fcf4bb37305,7fcf4b104e6a,7fcf4b113705,7fcf49573ba1,7fcf492ba795,7fcf492ba9db,7fcf492972f9,5f5e78,903aff&map=2a1715f2c413d37900cabb303a7e0c76:7fcf48d5c000-7fcf4f8378e1,8de474c84c849036fc788bcde7d9ce73:7fcf39a9b000-7fcf4852e680
*** SIGABRT received by PID 545434 (TID 545434) on cpu 25 from PID 545434; stack trace: ***
PC: @ 0x7fcf5129603b (unknown) raise
@ 0x7fcf481c147a 992 (unknown)
@ 0x7fcf512960c0 262631056 (unknown)
@ 0x7fcf4405434c 128 (unknown)
@ 0x7fcf44055177 192 (unknown)
@ 0x7fcf42732dbc 1328 TpuCompiler_RunHloPasses
@ 0x7fcf4db93c23 4160 xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
@ 0x7fcf4bb4a197 544 xla::Service::BuildExecutable()
@ 0x7fcf4bb400c8 1056 xla::LocalService::CompileExecutables()
@ 0x7fcf4bb37306 2432 xla::LocalClient::Compile()
@ 0x7fcf4b104e6b 832 xla::PjRtStreamExecutorClient::Compile()
@ 0x7fcf4b113706 1088 xla::PjRtStreamExecutorClient::Compile()
@ 0x7fcf49573ba2 1040 xla::PyClient::CompileMlir()
@ 0x7fcf492ba796 1696 pybind11::detail::argument_loader<>::call_impl<>()
@ 0x7fcf492ba9dc 208 pybind11::cpp_function::initialize<>()::{lambda()#3}::operator()()
@ 0x7fcf492972fa 640 pybind11::cpp_function::dispatcher()
@ 0x5f5e79 (unknown) PyCFunction_Call
@ 0x903b00 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fcf5129603b,7fcf481c1479,7fcf512960bf,7fcf4405434b,7fcf44055176,7fcf42732dbb,7fcf4db93c22,7fcf4bb4a196,7fcf4bb400c7,7fcf4bb37305,7fcf4b104e6a,7fcf4b113705,7fcf49573ba1,7fcf492ba795,7fcf492ba9db,7fcf492972f9,5f5e78,903aff&map=2a1715f2c413d37900cabb303a7e0c76:7fcf48d5c000-7fcf4f8378e1,8de474c84c849036fc788bcde7d9ce73:7fcf39a9b000-7fcf4852e680
E0304 08:36:11.810099 545434 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0304 08:36:11.810108 545434 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0304 08:36:11.810123 545434 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0304 08:36:11.810127 545434 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0304 08:36:11.810138 545434 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0304 08:36:11.810145 545434 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0304 08:36:11.810149 545434 coredump_hook.cc:550] RAW: Discarding core.
E0304 08:36:12.046387 545434 process_state.cc:771] RAW: Raising signal 6 with default behavior
[1] 545434 abort (core dumped) python test.py
If I remove the subprocess call, I don't get the error; everything works fine.
Thanks for the speedy replies!
For the libtpu-nightly==0.1.dev20220218 SIGABRT failure, please feel free to report that kind of thing here! In this case, we're already aware of the issue and should have a fixed libtpu-nightly out soon (apologies for suggesting you try it; I forgot about this issue).
Thanks also for isolating where the "Core halted unexpectedly" error began. This will help with debugging.
Unfortunately, I'm not able to share my code at the moment, but I can confirm that it also fails with JAX 0.2.28.
Regarding subprocesses: I was indeed using multiprocessing to spawn child processes for CPU-side data preprocessing. I found a workaround: I don't import JAX or perform any JAX operations until after all the child processes I need have been spawned. With that ordering, the same JAX training code works without a problem. It would be great to get this resolved, since multiprocess data preprocessing is common practice when training large-scale models.
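A minimal sketch of that ordering, assuming stdlib multiprocessing and a hypothetical preprocess step (the JAX import is shown as a comment so the sketch runs without a TPU):

```python
import multiprocessing as mp

def preprocess(x):
    # CPU-only data preprocessing; must not touch JAX or the TPU driver.
    return x * 2

def main():
    # 1. Create all worker processes *before* JAX is imported, so no
    #    child process ever inherits initialized TPU runtime state.
    with mp.Pool(processes=4) as pool:
        batches = pool.map(preprocess, range(8))

    # 2. Only now import JAX and start the TPU work.
    # import jax
    # train(jax.numpy.asarray(batches))
    return batches

if __name__ == "__main__":
    print(main())
```

The key point is the ordering (workers first, JAX after), not the specific pool API; the same idea applies to hand-rolled mp.Process workers.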
Thanks for helping me with this problem!
I'm seeing this with TensorFlow 2.8 on a TPU v2-8.
I get this issue with libtpu 0.1.dev20240617 and jax==0.4.30 when this section of my code runs on a TPU v4-32:
epochs = tqdm(args, kwargs)
...
...
...
train_metric = jax_utils.unreplicate(train_metric)
epochs.write(f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']} | Learning rate: {train_metric['learning_rate']})")
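For context, the jax_utils.unreplicate call above (from flax) just strips the leading device axis by taking index 0 of every leaf of the replicated metrics pytree; here is a stdlib-only sketch of the idea, with made-up metric values:

```python
def unreplicate(tree):
    # Take the first device's copy of every leaf in a replicated metrics
    # pytree; flax's jax_utils.unreplicate does essentially this via
    # jax.tree_map(lambda x: x[0], tree).
    if isinstance(tree, dict):
        return {k: unreplicate(v) for k, v in tree.items()}
    return tree[0]  # leading axis is the device axis

# Hypothetical metrics replicated across 4 devices:
replicated = {"loss": [0.5, 0.5, 0.5, 0.5], "learning_rate": [1e-3] * 4}
print(unreplicate(replicated))
```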
The script used to run normally on a Cloud TPU v2-8 VM, but now it shows an error:
Error message:
Library versions: