pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[question] Seeking information on low-level TPU interaction and libtpu.so API #7803

Open notlober opened 1 month ago

notlober commented 1 month ago

I'm looking to build an automatic differentiation library for TPUs without using high-level front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding that information about lower-level TPU usage is practically nonexistent.

Specifically, I'm interested in:

  1. How to interact with TPUs at a lower level than what's typically exposed in TensorFlow
  2. Information about the libtpu.so library and its API
  3. Any resources or documentation on implementing custom TPU operations

Are there any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated.

I understand that some of this information might be proprietary, but any guidance on what is possible or available would be very helpful.

JackCaoG commented 1 month ago

@will-cromar should be able to share some information.

will-cromar commented 1 month ago

All three frameworks interact with libtpu through the PJRT plugin API. Most of the core API for PJRT is documented in comments here: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h

Almost all of our interactions with PJRT are in this folder, and it's largely independent from PyTorch itself: https://github.com/pytorch/xla/tree/master/torch_xla/csrc/runtime

Specifically, to create a PJRT TPU client, you would need to go through the PjRtCApiClient similar to this (device_type = "tpu", library_path = "/path/to/libtpu.so"): https://github.com/pytorch/xla/blob/dd3b00ca95f455ffd9eec15803429420efbd106a/torch_xla/csrc/runtime/pjrt_registry.cc#L118-L126
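For orientation, here is a minimal sketch of the lowest-level step underneath that registry code: loading libtpu.so and resolving its PJRT C API entry point. It assumes the plugin follows the PJRT C API plugin convention of exporting an `extern "C"` function named `GetPjrtApi` that returns a pointer to the `PJRT_Api` function table declared in XLA's `xla/pjrt/c/pjrt_c_api.h`. `PluginHandle` and `LoadPjrtPluginSketch` are hypothetical names, not torch_xla or XLA APIs, and the table is kept opaque so the snippet compiles without XLA headers. In practice torch_xla goes through XLA's own plugin loader and `PjRtCApiClient` rather than calling `dlopen` by hand.

```cpp
#include <dlfcn.h>
#include <string>

// Hypothetical result type for this sketch; not part of XLA or torch_xla.
struct PluginHandle {
  void* library = nullptr;    // handle returned by dlopen
  const void* api = nullptr;  // GetPjrtApi() result, opaque here (really const PJRT_Api*)
  std::string error;          // non-empty on failure
};

// Hypothetical helper: dlopen a PJRT plugin (e.g. "/path/to/libtpu.so") and
// call its exported GetPjrtApi entry point. Assumption: the symbol name and
// zero-argument signature follow the PJRT C API plugin convention.
PluginHandle LoadPjrtPluginSketch(const std::string& library_path) {
  PluginHandle result;
  // RTLD_NOW: resolve all symbols up front; RTLD_LOCAL: don't make the
  // plugin's symbols available to libraries loaded later.
  result.library = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (result.library == nullptr) {
    const char* msg = dlerror();
    result.error = msg ? msg : "dlopen failed";
    return result;
  }

  using GetPjrtApiFn = const void* (*)();
  dlerror();  // clear any stale error before calling dlsym
  auto get_api =
      reinterpret_cast<GetPjrtApiFn>(dlsym(result.library, "GetPjrtApi"));
  if (get_api == nullptr) {
    const char* msg = dlerror();
    result.error = msg ? msg : "GetPjrtApi not found";
    dlclose(result.library);
    result.library = nullptr;
    return result;
  }

  // On success the library is intentionally left open: the returned function
  // table points into the plugin and must stay valid while it is in use.
  result.api = get_api();
  if (result.api == nullptr) {
    result.error = "GetPjrtApi returned null";
  }
  return result;
}
```

Usage would be to call `LoadPjrtPluginSketch("/path/to/libtpu.so")` and check `error` before casting `api` to `const PJRT_Api*` with the real header included. On older glibc versions you may need to link with `-ldl`.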

Once you have a client instantiated, then your interactions are going to look a lot like this example from JAX: https://github.com/google/jax/blob/main/examples/jax_cpp/main.cc

We use the PJRT C++ API directly, but it's worth noting that (other than the example above) JAX mainly interacts with PJRT through its Python bindings. I'm not nearly as familiar with those, so you'll have better luck asking in the JAX repository if you want to use the same bindings.

The framework code outside of libtpu.so is all open source. I'm happy to help if you have any questions about the PJRT C++ API.