google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TPU not working #13260

Open chrisflesher opened 1 year ago

chrisflesher commented 1 year ago

Description

I created a new TPU VM and successfully used it, but after stopping / starting the VM I am getting the following error:

Python 3.8.10 (default, Jun 22 2022, 20:18:18) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
Traceback (most recent call last):
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 333, in backends
    backend = _init_backend(platform)
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 385, in _init_backend
    backend = factory()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/chris/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 122, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: No ba16c7433 device found.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 483, in devices
    return get_backend(backend).devices()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 425, in get_backend
    return _get_backend_uncached(platform)
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 409, in _get_backend_uncached
    bs = backends()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 350, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': NOT_FOUND: No ba16c7433 device found. (set JAX_PLATFORMS='' to automatically choose an available backend)

What jax/jaxlib version are you using?

jax 0.3.25[tpu], jaxlib 0.3.25

Which accelerator(s) are you using?

TPU

Additional system info

No response

NVIDIA GPU info

No response

skye commented 1 year ago

Did this work before you stopped/started it? Can you share /tmp/tpu_logs/tpu_driver.INFO from the TPU VM after getting this error? Also can you share the TPU creation command you used, or at least the TPU type?

chrisflesher commented 1 year ago

Yes, all jax TPU stuff seems to work fine until I stop / start the instance.

chris@t1v-n-4f57df69-w-0 ~ $ cat /tmp/tpu_logs/tpu_driver.INFO
Log file created at: 2022/11/15 21:16:04
Running on machine: t1v-n-4f57df69-w-0
Binary: Built on Nov 9 2022 02:08:18 (1667988479)
Binary: Built at cloud-tpus-runtime-release-tool@vqzj18.prod.google.com:/google/src/cloud/buildrabbit-username/buildrabbit-client/g3     
Binary: Built for gcc-4.X.Y-crosstool-v18-llvm-grtev4-k8
Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg
I1115 21:16:04.142176    3416 b295d63588a.cc:731] Linux version 5.15.0-1021-gcp (buildd@lcy02-amd64-008) (gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0, GNU ld (GNU Binutils for Ubuntu) 2.34) #28~20.04.1-Ubuntu SMP Mon Oct 17 11:37:54 UTC 2022
I1115 21:16:04.142356    3416 b295d63588a.cc:798] Process id 3416
I1115 21:16:04.142365    3416 b295d63588a.cc:803] Current working directory /home/chris
I1115 21:16:04.142366    3416 b295d63588a.cc:805] Current timezone is UTC (currently UTC +00:00)
I1115 21:16:04.142368    3416 b295d63588a.cc:809] Built on Nov 9 2022 02:08:18 (1667988479)
I1115 21:16:04.142369    3416 b295d63588a.cc:810]  at cloud-tpus-runtime-release-tool@vqzj18.prod.google.com:/google/src/cloud/buildrabbit-username/buildrabbit-client/g3     
I1115 21:16:04.142370    3416 b295d63588a.cc:811]  as //learning/45eac/tfrc/executor:_libtpu.so
I1115 21:16:04.142370    3416 b295d63588a.cc:812]  for gcc-4.X.Y-crosstool-v18-llvm-grtev4-k8
I1115 21:16:04.142371    3416 b295d63588a.cc:815]  from changelist 487166979 with baseline 487129179 in a mint client based on __ar56t/branches/libtpu_runtime_release_branch/487129179.1/g3     
I1115 21:16:04.142372    3416 b295d63588a.cc:819] Build label: libtpu_202211090101_RC00
I1115 21:16:04.142373    3416 b295d63588a.cc:821] Build tool: Bazel, release r4rca-2022.10.26-2 (mainline @483757624)
I1115 21:16:04.142373    3416 b295d63588a.cc:822] Build target: 
I1115 21:16:04.142374    3416 b295d63588a.cc:829] Command line arguments:
I1115 21:16:04.142375    3416 b295d63588a.cc:831] argv[0]: './tpu_driver'
I1115 21:16:04.142377    3416 b295d63588a.cc:831] argv[1]: '--minloglevel=0'
I1115 21:16:04.142378    3416 b295d63588a.cc:831] argv[2]: '--stderrthreshold=3'
I1115 21:16:04.142378    3416 b295d63588a.cc:831] argv[3]: '--v=0'
I1115 21:16:04.142379    3416 b295d63588a.cc:831] argv[4]: '--vmodule='
I1115 21:16:04.142380    3416 b295d63588a.cc:831] argv[5]: '--log_dir=/tmp/tpu_logs'
I1115 21:16:04.142381    3416 b295d63588a.cc:831] argv[6]: '--max_log_size=1024'
I1115 21:16:04.142382    3416 b295d63588a.cc:831] argv[7]: '--enforce_kernel_ipv6_support=false'
I1115 21:16:04.142383    3416 b295d63588a.cc:831] argv[8]: '--tpu_use_tfrt=false'
I1115 21:16:04.142384    3416 b295d63588a.cc:831] argv[9]: '--2a886c8_chips_per_host_bounds=2,2,1'
I1115 21:16:04.142385    3416 b295d63588a.cc:831] argv[10]: '--2a886c8_host_bounds=1,1,1'
I1115 21:16:04.142385    3416 b295d63588a.cc:831] argv[11]: '--2a886c8_slice_builder_worker_port=8471'
I1115 21:16:04.142386    3416 b295d63588a.cc:831] argv[12]: '--2a886c8_slice_builder_worker_addresses=10.128.0.23:8471'
I1115 21:16:04.142737    3416 tpu_runtime_type_flags.cc:48] --tpu_use_tfrt is specified. Value: false
I1115 21:16:04.253008    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_a5rtred_torus_config
I1115 21:16:04.253033    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_a5rtred_torus_data
I1115 21:16:04.253039    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_config
I1115 21:16:04.253043    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_x_data
I1115 21:16:04.253052    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_y_data
I1115 21:16:04.253059    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_z_data
I1115 21:16:04.253073    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/routing_cache_files
I1115 21:16:04.253082    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_default
I1115 21:16:04.253089    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_inference
I1115 21:16:04.253094    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_legacy
I1115 21:16:04.253096    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore_dense
I1115 21:16:04.253098    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore
I1115 21:16:04.253103    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore_inference
I1115 21:16:04.253105    3416 builtin.cc:16] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_parts_memfile
I1115 21:16:04.271633    3416 coredump_hook.cc:742] Remote crash gathering hook installed.
I1115 21:16:04.297953    3416 logger.cc:296] Enabling threaded logging for severity WARNING
I1115 21:16:04.300211    3416 mlock.cc:214] mlock()-ed 4096 bytes for BuildID, using 1 syscalls.
I1115 21:16:04.426761    3416 global_config_custom.cc:114] GlobalConfig(grpc_verbosity) is set by environment variable. Please use flag(--grpc_verbosity) only.
I1115 21:16:04.485445    3416 init-domain.cc:91] Fiber init: default domain = futex, concurrency = 105, prefix = futex-default
W1115 21:16:04.622459    3416 tpu_version_flag.cc:57] No hardware is found. Using default TPU version: ba16c7433
I1115 21:16:04.647434    3416 flags_util.cc:244] Using 8471 from --2a886c8_slice_builder_worker_port as SliceBuilder worker service port.
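The driver log uses the line format declared in its own header (`[IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg`). In a long log, filtering for non-INFO lines makes entries like the `No hardware is found` warning easier to spot. This is a minimal sketch that assumes the log lives at `/tmp/tpu_logs/tpu_driver.INFO`, as in the comment above. `parse_glog_line` and `warnings_and_errors` are hypothetical helpers and are not part of JAX or libtpu.

```python
import os
import re
from typing import Iterable, Iterator, Optional

# Line format declared in the log header:
#   [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg
GLOG_LINE = re.compile(
    r"^(?P<severity>[IWEF])(?P<month>\d{2})(?P<day>\d{2}) "
    r"(?P<time>\d{2}:\d{2}:\d{2}\.\d{6})\s+(?P<thread>\d+) "
    r"(?P<file>[^:\]]+):(?P<line>\d+)\] (?P<msg>.*)$"
)

LOG_PATH = "/tmp/tpu_logs/tpu_driver.INFO"  # path taken from the comment above


def parse_glog_line(line: str) -> Optional[dict]:
    """Hypothetical helper: return the fields of one log line, or None for header lines."""
    m = GLOG_LINE.match(line.rstrip("\n"))
    if m is None:
        return None
    entry = m.groupdict()
    entry["thread"] = int(entry["thread"])
    entry["line"] = int(entry["line"])
    return entry


def warnings_and_errors(lines: Iterable[str]) -> Iterator[dict]:
    """Hypothetical helper: yield entries with severity W (warning), E (error) or F (fatal)."""
    for line in lines:
        entry = parse_glog_line(line)
        if entry is not None and entry["severity"] in ("W", "E", "F"):
            yield entry


if __name__ == "__main__" and os.path.exists(LOG_PATH):
    with open(LOG_PATH) as f:
        for e in warnings_and_errors(f):
            print(f'{e["severity"]} {e["file"]}:{e["line"]} {e["msg"]}')
```

When applied to the log above, the filter keeps only the `W` line from `tpu_version_flag.cc:57`, which reports that no hardware was found before falling back to the default TPU version.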
skye commented 1 year ago

Hm, can you check that /dev/accel0 exists on the VM? What happens if you try stopping / starting it again?

chrisflesher commented 1 year ago
chris@t1v-n-4f57df69-w-0 ~ $ ls /dev/accel0
ls: cannot access '/dev/accel0': No such file or directory
chrisflesher commented 1 year ago

I tried restarting the instance and got the same result, /dev/accel0 is not found.

skye commented 1 year ago

Just to clarify, this is definitely a bug :) Can you confirm the command you used to create the VM? If possible, can you also share the project name and TPU name? (Feel free to email me instead of posting publicly)

chrisflesher commented 1 year ago

Here was the command I used to create:

gcloud compute tpus tpu-vm create ferrar --zone us-central1-c --accelerator-type v2-8 --version tpu-vm-base

Sent an e-mail with our project name just now.

skye commented 1 year ago

Thanks. We're still looking into this. In the meantime and if you're able, creating a new VM will likely resolve the issue (definitely let us know if it doesn't).

chrisflesher commented 1 year ago

Yes, creating a new VM seems to work until we stop / start the instance. We're using this as a workaround until this issue can be resolved.

melisandeteng commented 1 year ago

Hello, I seem to be running into the same issue, although I did not stop/start the TPU VM. My experiment just stopped running, and when I try to run it again I get the following error, even though it was working just fine before:


  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 333, in backends
    backend = _init_backend(platform)
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 385, in _init_backend
    backend = factory()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/melisande/.local/lib/python3.9/site-packages/jaxlib/xla_client.py", line 122, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:    "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/melisande/scenic-glc/scenic/projects/glc/main.py", line 44, in <module>
    devices = jax.local_devices()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 515, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 425, in get_backend
    return _get_backend_uncached(platform)
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 409, in _get_backend_uncached
    bs = backends()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 350, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:        "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance (set JAX_PLATFORMS='' to automatically choose an available backend)
skye commented 1 year ago

@melisandeteng just to verify, did you check that /dev/accel0 doesn't exist on the VM (e.g. ls /dev/accel0 doesn't show anything)? That error can also sometimes happen if another process is already using the TPU.

zw615 commented 1 year ago

@skye Hi, I am running into a similar issue here. The error message really looks like

jaxlib.xla_extension.XlaRuntimeError: ABORTED: The TPU is already in use by process with pid 1378725. Not attempting to load libtpu.so in this process.

RuntimeError: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 1378725. Not attempting to load libtpu.so in this process. (set JAX_PLATFORMS='' to automatically choose an available backend)

And I have tried ls /dev/accel0 and found that the device does exist. How can I find the process that is using the TPU device and kill it? It's probably some zombie process that was not killed completely.

ayaka14732 commented 1 year ago

@zeyuwang615 If the TPU is already in use by process with pid 1378725, you should use

kill -9 1378725

to kill it.

zw615 commented 1 year ago

@ayaka14732 Yes, so my question really is how to find the process that is using the TPU device, given that /dev/accel0 exists.

skye commented 1 year ago

@zeyuwang615 the process ID of the process using the TPU is included in the error message (that's how @ayaka14732 knew to kill pid 1378725 from your above example). You can also use sudo lsof -w /dev/accel0 on the command line to find processes using /dev/accel0 (one of the TPU chips).
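Since the `ABORTED` message embeds the pid, a script can extract it directly instead of copying it by hand. This is a small sketch. `pid_from_tpu_in_use_error` is a hypothetical helper, and the regex assumes the exact wording `The TPU is already in use by process with pid N` quoted above; other libtpu versions may phrase the message differently.

```python
import re
from typing import Optional

# Wording taken from the ABORTED error quoted above; other libtpu versions may differ.
_IN_USE = re.compile(r"TPU is already in use by process with pid (\d+)")


def pid_from_tpu_in_use_error(message: str) -> Optional[int]:
    """Hypothetical helper: return the pid named in a 'TPU is already in use' error, or None."""
    m = _IN_USE.search(message)
    return int(m.group(1)) if m else None


if __name__ == "__main__":
    msg = ("ABORTED: The TPU is already in use by process with pid 1378725. "
           "Not attempting to load libtpu.so in this process.")
    pid = pid_from_tpu_in_use_error(msg)
    if pid is not None:
        # Print the command rather than calling os.kill(), so the target can be checked first.
        print(f"kill -9 {pid}")
```

Printing the command instead of signalling the process gives you a chance to confirm the pid first (for example with `ps -p`).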

Please file a new issue if you have more questions. I'd like to keep this thread focused on the missing /dev/accel0 driver.
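On hosts without `lsof`, the `sudo lsof -w /dev/accel0` check above can be approximated in Python on Linux. The approach walks `/proc/<pid>/fd` and resolves each descriptor's symlink. This is a minimal sketch and `pids_holding` is a hypothetical name. As with `lsof`, seeing other users' processes needs root; the sketch skips unreadable `/proc` entries without reporting them.

```python
import glob
import os
from typing import Dict, List, Set


def pids_holding(paths: List[str]) -> Dict[str, List[int]]:
    """Hypothetical helper: map each path to the sorted pids that have it open (Linux /proc only)."""
    wanted = {os.path.realpath(p): p for p in paths}
    holders: Dict[str, Set[int]] = {p: set() for p in paths}
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue  # skip non-process entries such as /proc/meminfo
        fd_dir = os.path.join("/proc", entry, "fd")
        try:
            fds = os.listdir(fd_dir)
        except OSError:
            continue  # process exited, or no permission (run with sudo, like lsof)
        for fd in fds:
            try:
                target = os.readlink(os.path.join(fd_dir, fd))
            except OSError:
                continue  # descriptor closed between listdir and readlink
            if target in wanted:
                holders[wanted[target]].add(int(entry))
    return {p: sorted(pids) for p, pids in holders.items()}


if __name__ == "__main__":
    for dev, pids in pids_holding(sorted(glob.glob("/dev/accel*"))).items():
        print(dev, pids if pids else "not open by any visible process")
```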

vjeronymo2 commented 1 year ago

I'm also having the same problem as @melisandeteng, and I'm able to find /dev/accel0

ls -ltrh /dev/accel*
crwxrwxrwx 1 root root 121, 3 Mar 20 21:09 /dev/accel3
crwxrwxrwx 1 root root 121, 2 Mar 20 21:09 /dev/accel2
crwxrwxrwx 1 root root 121, 1 Mar 20 21:09 /dev/accel1
crwxrwxrwx 1 root root 121, 0 Mar 20 21:09 /dev/accel0

I even tried running chmod 777 on them, but I still get the error when running accelerate test:

stderr: 2023-03-21 17:31:48.125790: E tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:601] PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:         "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance
stderr: 2023-03-21 17:31:48.125843: F tensorflow/compiler/xla/xla_client/xrt_local_service.cc:58] Non-OK-status: tensorflow::NewServer(server_def, &server_) status: PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:   "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance
stderr: https://symbolize.stripped_domain/r/?trace=7f8b3adb600b,7f8b3b0d341f,7f8a8c7aef1b,7f8a8cf0c1d5,7f8a8cf19f98,7f8a8cecc0f4,7f8a8ceccfe2,7f8b3b0d04de&map=04ceea301ec570e6abcf4ef3f089f0fde6516664:7f8a89c8f000-7f8a9d6e65e0
stderr: *** SIGABRT received by PID 269752 (TID 269752) on cpu 85 from PID 269752; stack trace: ***
stderr: PC: @     0x7f8b3adb600b  (unknown)  raise
stderr:     @     0x7f8a89063a1a       1152  (unknown)
stderr:     @     0x7f8b3b0d3420  (unknown)  (unknown)
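The `ls -ltrh` listing above shows character devices (the leading `c`) with major number 121, minor numbers 0 to 3, and mode `rwxrwxrwx`. A mode-bit refusal from `open(2)` is normally `EACCES` ("Permission denied"), but the error here is "Operation not permitted", the text for `EPERM`. That suggests the refusal comes from somewhere other than the file mode, which fits with `chmod 777` not helping. This is a minimal sketch for collecting the same facts from Python. `describe_nodes` and `DeviceNode` are hypothetical names.

```python
import glob
import os
import stat
from typing import List, NamedTuple


class DeviceNode(NamedTuple):
    path: str
    is_char_device: bool
    major: int  # -1 when the path is not a device node
    minor: int
    mode: str   # same notation as `ls -l`, e.g. "crwxrwxrwx"


def describe_nodes(pattern: str = "/dev/accel*") -> List[DeviceNode]:
    """Hypothetical helper: stat each matching path and report type, device numbers and mode."""
    nodes = []
    for path in sorted(glob.glob(pattern)):
        st = os.stat(path)
        is_dev = stat.S_ISCHR(st.st_mode) or stat.S_ISBLK(st.st_mode)
        nodes.append(DeviceNode(
            path=path,
            is_char_device=stat.S_ISCHR(st.st_mode),
            major=os.major(st.st_rdev) if is_dev else -1,
            minor=os.minor(st.st_rdev) if is_dev else -1,
            mode=stat.filemode(st.st_mode),
        ))
    return nodes


if __name__ == "__main__":
    for node in describe_nodes():
        print(node)
```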
mosmos6 commented 1 year ago

I have just encountered the same issue. Everything was working fine with JAX 0.3.25 until I exited the VM. I'm also having the same issue on Colab with JAX 0.3.25.

KeremTurgutlu commented 1 year ago

Same here: us-central1-a | v3-8 | v2-alpha-pod | TPU VM, following the t5x installation steps. I was able to see /dev/accel[0-3]. Stopping and restarting also didn't help.

python3 -c "import jax; print(jax.local_devices())"

Traceback (most recent call last):
  File "/home/keremturgutlu/t5_venv/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 420, in backends
    backend = _init_backend(platform)
  File "/home/keremturgutlu/t5_venv/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 473, in _init_backend
    backend = factory()
  File "/home/keremturgutlu/t5_venv/lib/python3.9/site-packages/jax/_src/xla_bridge.py", line 179, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/keremturgutlu/t5_venv/lib/python3.9/site-packages/jaxlib/xla_client.py", line 148, in make_tpu_client
    return make_tfrt_tpu_c_api_client()
  File "/home/keremturgutlu/t5_venv/lib/python3.9/site-packages/jaxlib/xla_client.py", line 106, in make_tfrt_tpu_c_api_client
    return _xla.get_c_api_client('tpu', options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: TPU initialization failed: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:   "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance
contrebande-labs commented 1 year ago

Same here. I don't know if it worked before I updated/upgraded the APT packages and rebooted the machine, but now I get:

python3 -c "import jax; print(jax.local_devices())"
Traceback (most recent call last):
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 393, in backends
    backend = _init_backend(platform)
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 445, in _init_backend
    backend = factory()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 187, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/vincent/charred/lib/python3.8/site-packages/jaxlib/xla_client.py", line 147, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: UNAVAILABLE: No TPU Platform available.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 575, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 485, in get_backend
    return _get_backend_uncached(platform)
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 469, in _get_backend_uncached
    bs = backends()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 410, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': UNAVAILABLE: No TPU Platform available. (set JAX_PLATFORMS='' to automatically choose an available backend)

and there is no device either:

ls -la /dev/accel*
ls: cannot access '/dev/accel*': No such file or directory
stevenlimcorn commented 1 year ago

I have the same issue running JAX on Kaggle with a TPU VM v3-8 accelerator. I got this error when running this script: https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_bart_dlm_flax.py

Traceback (most recent call last):
  File "/kaggle/working/transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 967, in <module>
    main()
  File "/kaggle/working/transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 682, in main
    if has_tensorboard and jax.process_index() == 0:
  File "/usr/local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 595, in process_index
    return get_backend(backend).process_index()
  File "/usr/local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 485, in get_backend
    return _get_backend_uncached(platform)
  File "/usr/local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 469, in _get_backend_uncached
    bs = backends()
  File "/usr/local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 410, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:    "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance (set JAX_PLATFORMS='' to automatically choose an available backend)

Any ideas on how to resolve this issue?

skye commented 1 year ago

Hi all, sorry for not responding earlier. Thank you all for the reports, all information is very useful while we debug this.

@stevenlimcorn are you able to check for the existence of /dev/accel0 on the notebook? Something like ! ls /dev/accel*.

It seems like there may be two issues here. One is that /dev/accel0 (and the rest of the TPU drivers) are missing altogether. That's when ls /dev/accel* on the TPU VM doesn't find anything. The other is that the drivers are there, but running jax still gets Couldn't open device: /dev/accel0.

While we look into this more, the best I can suggest for now is:

I'm honestly not sure these mitigations will work, but they're my best guess for what might help. Please report back if you try any of these and whether they helped or not. That information will help us debug. Thanks, and sorry for the trouble.
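The two cases described above can be told apart with a quick check before further debugging. This sketch relies only on the `/dev/accel*` naming seen in this thread. The hints paraphrase suggestions made earlier in the thread: creating a new VM when the drivers are missing, and checking for another process holding the TPU when they are present. `tpu_driver_state` is a hypothetical helper.

```python
import glob
from typing import Tuple


def tpu_driver_state(pattern: str = "/dev/accel*") -> Tuple[str, str]:
    """Hypothetical helper: report which of the two failure modes applies on this host."""
    nodes = sorted(glob.glob(pattern))
    if not nodes:
        # Case 1: the driver device nodes are missing altogether.
        return ("drivers_missing", "no TPU device nodes found")
    # Case 2: the nodes exist, so an open() failure has another cause.
    return ("drivers_present",
            f"found {', '.join(nodes)}; check whether another process already holds the TPU")


if __name__ == "__main__":
    state, hint = tpu_driver_state()
    print(f"{state}: {hint}")
```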

owos commented 1 year ago

Hi, have you found a solution to this problem?

JudeDavis1 commented 1 year ago

I'm getting this as well.

AakashKumarNain commented 1 year ago

@skye @mattjj I faced the same issue today on Kaggle. There is an interesting thing that I found though. If you are importing TensorFlow before JAX, then somehow JAX can't initialize the TPU runtime.

I am sure that if I removed everything related to TF from the code, it would work. The problem is that we then wouldn't be able to use tfds for loading and processing datasets. Let me know if you need any more info from my side.

Update: I figured out the root cause, at least for Kaggle Notebooks. Depending on which framework accesses the TPU first (TensorFlow or JAX), the other won't be able to initialize and use it. The bug is definitely related to TF, as I am unable to hide the TPU system from the list of available devices. I'm going to open an issue in the TF repo.
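If importing TensorFlow first is what locks JAX out, as reported here, a cheap guard is to check whether TensorFlow is already loaded before JAX queries its devices. This is a minimal sketch. It only inspects `sys.modules`, so it catches the import-order case but cannot tell whether TensorFlow has actually claimed the TPU. `warn_if_tensorflow_loaded` is a hypothetical name.

```python
import sys
import warnings


def warn_if_tensorflow_loaded() -> bool:
    """Hypothetical helper: warn and return True if TensorFlow was imported earlier in this process."""
    if "tensorflow" in sys.modules:
        warnings.warn(
            "tensorflow was imported before JAX touched the TPU; TensorFlow may "
            "already hold the TPU, leaving JAX unable to initialize it.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False


# Intended to run before the first JAX device query (e.g. jax.devices()).
warn_if_tensorflow_loaded()
```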

yoinked-h commented 1 year ago

I have this issue too, with XRT on PyTorch, using accelerate with TPU [max_workers=8]. The accelX devices are there, but I get a permission error. I ran accelerate with --tpu_use_sudo, with no luck.

Note: accelerate test came out all clean, so it must be something to do with using all cores.

djmango commented 12 months ago

Same issue on a TPU v4-8 VM, (started and stopped)

defdet commented 10 months ago

I faced the same thing with PyTorch XLA. Uninstalling TensorFlow seems to work.

gauravbrills commented 9 months ago

I faced the same issue on a Kaggle TPU v4-8 as well, when using Keras NLP.

skye commented 9 months ago

Ah, yes, only one framework can use the TPU at a time, so if you import tensorflow before jax, jax won't be able to access the TPU. I suggest uninstalling tensorflow and reinstalling the CPU-only version, something like:

!pip uninstall -y tensorflow
!pip install tensorflow-cpu

That way you can use tensorflow for non-TPU functions like tf.data, and jax for the TPU parts.

@gauravbrills I'm not sure this workaround will work for Keras NLP. I believe there are some preliminary discussions about allowing multiple frameworks to access the TPU concurrently, but I don't have a timeline for this yet. I'm not very familiar with Keras -- are you trying to use Keras with a JAX backend?
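After switching to `tensorflow-cpu` as suggested above, you can confirm which TensorFlow distributions are installed using the standard library's `importlib.metadata` (Python 3.8+). This sketch assumes the distribution names used in the pip commands above. Treating the presence of the full `tensorflow` package as a TPU-contention risk is an assumption drawn from this thread, not documented behavior. `installed_versions` and `full_tensorflow_present` are hypothetical helpers.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(names: Iterable[str] = ("tensorflow", "tensorflow-cpu")) -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def full_tensorflow_present(found: Dict[str, Optional[str]]) -> bool:
    """Assumption: the non-CPU 'tensorflow' distribution is the one that may grab the TPU."""
    return found.get("tensorflow") is not None


if __name__ == "__main__":
    versions = installed_versions()
    print(versions)
    if full_tensorflow_present(versions):
        print("Full tensorflow is installed; the workaround swaps it for tensorflow-cpu.")
```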

skye commented 9 months ago

FYI @djherbis

gauravbrills commented 9 months ago

Yes, I was using os.environ["KERAS_BACKEND"] = "jax". Shall I change this to tensorflow, or try your approach above? Does using tensorflow-cpu mean my training will still run on the TPU via JAX?

Thanks but the above recommendations did work

skye commented 9 months ago

does using tensorflow-cpu mean my training will still run on tpu via jax ?

I'm honestly not sure, since I'm very unfamiliar with Keras. You may wanna post in the Keras issue tracker (please cc me if you do).

Thanks but the above recommendations did work

Does this mean you tried installing tensorflow-cpu and it worked? Or didn't work?

gauravbrills commented 9 months ago

Yes, installing tensorflow-cpu worked, and the JAX backend with Keras was now utilizing the TPU and blazingly fast. Thanks :D

fchollet commented 9 months ago

does using tensorflow-cpu mean my training will still run on tpu via jax ?

If you're using the JAX backend then it doesn't matter which version of TF you have installed. It will only be used for e.g. tf.data.

CaptainStiggz commented 7 months ago

@skye I'm having the same issue with the v5e TPU VM: ls /dev shows no accelerators at all. This has happened every time I've tried to create a TPU VM; I've tried 3 or 4 times to create a new one now with no luck, with both v5litepod-4 and v5litepod-1. I was mostly following the setup instructions listed here:

https://cloud.google.com/tpu/docs/run-calculation-jax

Is there a different version or process that's needed for v5 TPUs?

Create VM: gcloud alpha compute tpus tpu-vm create zdev-tpu5e --zone=us-west1-c --accelerator-type=v5litepod-4 --version=tpu-vm-base

Connect: gcloud compute tpus tpu-vm ssh zdev-tpu5e --zone=us-west1-c

Install JAX: pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Test:

python3
Python 3.8.10 (default, Mar 15 2022, 12:22:08)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.device_count()
tcmalloc: large alloc 4294967296 bytes == 0x4b6a000 @  0x7f65c1055680 0x7f65c1076824 0x7f65c1076b8a 0x7f651373934c 0x7f650c4017f2 0x7f650c3f0952 0x7f650c39af5f 0x7f650b828b2e 0x7f65b9a3aa59 0x7f65b96b0ad5 0x7f65b9687cbf 0x5f3989 0x5f3e1e 0x570674 0x6b2b5c 0x56b0ae 0x6b2b5c 0x570035 0x56939a 0x5f6a13 0x5f6242 0x66598d 0x5f3e1e 0x570674 0x5f6836 0x56b0ae 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x56b0ae
Traceback (most recent call last):
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 593, in backends
    backend = _init_backend(platform)
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 647, in _init_backend
    backend = registration.factory()
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 200, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/zach/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 169, in make_tpu_client
    return make_tfrt_tpu_c_api_client()
  File "/home/zach/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 107, in make_tfrt_tpu_c_api_client
    return _xla.get_c_api_client('tpu', options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: TPU initialization failed: No ba16c7433 device found.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 723, in device_count
    return int(get_backend(backend).device_count())
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 692, in get_backend
    return _get_backend_uncached(platform)
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 673, in _get_backend_uncached
    bs = backends()
  File "/home/zach/.local/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 609, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No ba16c7433 device found. (set JAX_PLATFORMS='' to automatically choose an available backend)

Logs:

Log file created at: 2024/02/06 23:21:41
Running on machine: t1v-n-6a2b8934-w-0
Binary: Built on Jun 22 2023 03:38:13 (1687430275)
Binary: Built at cloud-tpus-runtime-release-tool@odqe9.prod.google.com:/google/src/cloud/buildrabbit-username/buildrabbit-client/g3
Binary: Built for gcc-4.X.Y-crosstool-v18-llvm-grtev4-k8
Log line format: [IWEF]mmdd hh:mm:ss.uuuuuu threadid file:line] msg
I0206 23:21:41.495393    6454 b295d63588a.cc:726] Linux version 5.13.0-1027-gcp (buildd@lcy02-amd64-062) (gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0, GNU ld (GNU Binutils for Ubuntu) 2.34) #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
I0206 23:21:41.495570    6454 b295d63588a.cc:793] Process id 6454
I0206 23:21:41.495576    6454 b295d63588a.cc:798] Current working directory /home/zach
I0206 23:21:41.495577    6454 b295d63588a.cc:800] Current timezone is UTC (currently UTC +00:00)
I0206 23:21:41.495579    6454 b295d63588a.cc:804] Built on Jun 22 2023 03:38:13 (1687430275)
I0206 23:21:41.495580    6454 b295d63588a.cc:805]  at cloud-tpus-runtime-release-tool@odqe9.prod.google.com:/google/src/cloud/buildrabbit-username/buildrabbit-client/g3
I0206 23:21:41.495580    6454 b295d63588a.cc:806]  as //learning/45eac/tfrc/executor:_libtpu.so
I0206 23:21:41.495581    6454 b295d63588a.cc:807]  for gcc-4.X.Y-crosstool-v18-llvm-grtev4-k8
I0206 23:21:41.495582    6454 b295d63588a.cc:810]  from changelist 542473012 with baseline 542473012 in a mint client based on __ar56t/g3
I0206 23:21:41.495583    6454 b295d63588a.cc:814] Build label: libtpu_20230622_a_RC00
I0206 23:21:41.495583    6454 b295d63588a.cc:816] Build tool: Bazel, release r4rca-2023.06.15-1 (mainline @540397185)
I0206 23:21:41.495584    6454 b295d63588a.cc:817] Build target:
I0206 23:21:41.495585    6454 b295d63588a.cc:824] Command line arguments:
I0206 23:21:41.495585    6454 b295d63588a.cc:826] argv[0]: './tpu_driver'
I0206 23:21:41.495587    6454 b295d63588a.cc:826] argv[1]: '--minloglevel=0'
I0206 23:21:41.495588    6454 b295d63588a.cc:826] argv[2]: '--stderrthreshold=3'
I0206 23:21:41.495589    6454 b295d63588a.cc:826] argv[3]: '--v=0'
I0206 23:21:41.495590    6454 b295d63588a.cc:826] argv[4]: '--vmodule='
I0206 23:21:41.495590    6454 b295d63588a.cc:826] argv[5]: '--log_dir=/tmp/tpu_logs'
I0206 23:21:41.495591    6454 b295d63588a.cc:826] argv[6]: '--max_log_size=1024'
I0206 23:21:41.495592    6454 b295d63588a.cc:826] argv[7]: '--enforce_kernel_ipv6_support=false'
I0206 23:21:41.495593    6454 b295d63588a.cc:826] argv[8]: '--tpu_use_tfrt=false'
I0206 23:21:41.495594    6454 b295d63588a.cc:826] argv[9]: '--2a886c8_wrap=false,false,false'
I0206 23:21:41.495595    6454 b295d63588a.cc:826] argv[10]: '--2a886c8_chips_per_host_bounds=2,2,1'
I0206 23:21:41.495596    6454 b295d63588a.cc:826] argv[11]: '--2a886c8_host_bounds=1,1,1'
I0206 23:21:41.495597    6454 b295d63588a.cc:826] argv[12]: '--2a886c8_slice_builder_worker_port=8471'
I0206 23:21:41.495597    6454 b295d63588a.cc:826] argv[13]: '--2a886c8_slice_builder_worker_addresses=10.138.15.220:8471'
I0206 23:21:41.495598    6454 b295d63588a.cc:826] argv[14]: '--tpu_system_state_premapped_buffer_size=4294967296'
I0206 23:21:41.495599    6454 b295d63588a.cc:826] argv[15]: '--tpu_system_state_premapped_buffer_transfer_threshold_bytes=4294967296'
I0206 23:21:41.495666    6454 tpu_runtime_type_flags.cc:48] --tpu_use_tfrt is specified. Value: false
I0206 23:21:41.498918    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_a5rtred_torus_config
I0206 23:21:41.498924    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_a5rtred_torus_data
I0206 23:21:41.498928    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_config
I0206 23:21:41.498931    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_x_data
I0206 23:21:41.498940    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_y_data
I0206 23:21:41.498947    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce0_z_data
I0206 23:21:41.498952    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce_config
I0206 23:21:41.498955    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce_x_data
I0206 23:21:41.498964    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce_y_data
I0206 23:21:41.498969    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/cache_ici_resiliency_e0897fcce_z_data
I0206 23:21:41.498974    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/routing_cache_files
I0206 23:21:41.498981    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_default
I0206 23:21:41.498986    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_inference
I0206 23:21:41.498990    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_legacy
I0206 23:21:41.498992    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore_dense
I0206 23:21:41.498993    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore
I0206 23:21:41.498999    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_config_memfile_megacore_inference
I0206 23:21:41.499001    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/tpu_chip_parts_memfile
I0206 23:21:41.499005    6454 builtin.cc:17] 7edfa70aa11b3ffd6f: /memfile/vlc_architectural_resources_memfile
I0206 23:21:41.499037    6454 coredump_hook.cc:727] Remote crash gathering hook installed.
I0206 23:21:41.500676    6454 logger.cc:296] Enabling threaded logging for severity WARNING
I0206 23:21:41.501126    6454 mlock.cc:217] mlock()-ed 4096 bytes for BuildID, using 1 syscalls.
I0206 23:21:41.511976    6454 init-domain.cc:105] Fiber init: default domain = futex, concurrency = 123, prefix = futex-default
I0206 23:21:41.536622    6454 singleton_tpu_states_manager.cc:45] TpuStatesManager::GetOrCreate(): no tpu system exists. Creating a new tpu system.
W0206 23:21:41.536843    6454 tpu_version_flag.cc:57] No hardware is found. Using default TPU version: ba16c7433
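The important line in the log above is the tpu_version_flag.cc warning "No hardware is found. Using default TPU version: ...". It suggests libtpu could not see any TPU chips and fell back to a default version, which matches the NOT_FOUND error JAX raises. Below is a minimal Python sketch that scans the driver's log directory for that warning. It assumes the logs are written to /tmp/tpu_logs (the --log_dir value in the command-line arguments above) and that the warning text stays exactly as shown. The helper names find_no_hardware_warnings and scan_log_dir are hypothetical and are not part of JAX or libtpu.

```python
import glob
import os
import re

# Matches the libtpu warning shown in the log above; the version string after
# the colon (e.g. "ba16c7433") is captured because it may differ per machine.
NO_HW_PATTERN = re.compile(r"No hardware is found\. Using default TPU version: (\S+)")


def find_no_hardware_warnings(lines):
    """Return (line, default_version) pairs for every 'No hardware' warning."""
    hits = []
    for line in lines:
        match = NO_HW_PATTERN.search(line)
        if match:
            hits.append((line.rstrip("\n"), match.group(1)))
    return hits


def scan_log_dir(log_dir="/tmp/tpu_logs"):
    """Scan every regular file in log_dir (assumed to be libtpu's --log_dir).

    Returns a dict mapping file path -> list of matching warnings.
    A missing directory simply yields an empty dict.
    """
    results = {}
    for path in sorted(glob.glob(os.path.join(log_dir, "*"))):
        if not os.path.isfile(path):
            continue
        # errors="replace" so stray binary bytes in a log never abort the scan
        with open(path, errors="replace") as f:
            hits = find_no_hardware_warnings(f)
        if hits:
            results[path] = hits
    return results


if __name__ == "__main__":
    for path, hits in scan_log_dir().items():
        for line, version in hits:
            print(f"{path}: libtpu found no TPU hardware (fell back to {version})")
```

If the scan reports this warning, the runtime itself could not find the chips, so the problem is more likely the VM image or libtpu install than your JAX code.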
defdet commented 7 months ago

@CaptainStiggz, I'm not sure whether your case is the same as mine, but I noticed you used --version=tpu-vm-base. The gcloud docs state explicitly: "You must manually install JAX on your TPU VM, because there is no JAX-specific TPU software version. For all TPU versions, use tpu-ubuntu2204-base. The correct version of libtpu.so is automatically installed when you install JAX." For me, switching to that version fixed the problem immediately.

CaptainStiggz commented 7 months ago

@defdet this appears to have worked. Thanks! It's a little confusing because tpu-ubuntu2204-base is not listed as a software option in the console, but it works when passed on the command line. That made me think it was not a supported software option for v5.