jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Seeking information on low-level TPU interaction and libtpu.so API #22835

Closed: notlober closed this issue 3 months ago

notlober commented 4 months ago

I'm looking to build an automatic differentiation library for TPUs without using high-level front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding that information about lower-level TPU usage is practically non-existent.

Specifically, I'm interested in:

  1. How to interact with TPUs at a lower level than what's typically exposed in TensorFlow
  2. Information about the libtpu.so library and its API
  3. Any resources or documentation on implementing custom TPU operations

Are there any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated.

I understand that some of this information might be proprietary, but any guidance on what is possible or available would be very helpful.

apaszke commented 4 months ago

I don't think we have much documentation around libtpu or interfaces to TPUs beyond what's exposed in the Python libraries. But these days you can actually use the "high-level" libraries like JAX to generate very low-level code for TPUs using Pallas. I'd highly recommend reading through the guides I linked (they also contain links to TPU papers). If you have any other questions, let us know.
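
To give a concrete sense of what this looks like, here is a minimal sketch of a Pallas kernel, following the structure of the public Pallas quickstart (the names `add_kernel` and `add` are illustrative, not part of any fixed API):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs expose blocks of the operands in on-chip memory;
    # the kernel body is written as reads from and writes to them.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    # pallas_call wraps the kernel so it can be invoked
    # like an ordinary JAX function.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
print(add(x, x))
```

The Pallas guides go further into grids, BlockSpecs, and TPU-specific details such as memory spaces, which is where the real low-level control comes in.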

mattjj commented 3 months ago

I'm going to close this because I think the pointer to the Pallas docs is the best we can provide.