Closed: notlober closed this issue 3 months ago.
I don't think we have much documentation on libtpu or on interfaces to TPUs beyond what's exposed in the Python libraries. However, these days you can use "high-level" libraries like JAX to generate very low-level TPU code through Pallas, JAX's kernel language. I'd highly recommend reading through the guides I linked (they also contain links to TPU papers). If you have any other questions, let us know.
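To make the Pallas suggestion concrete, here is a minimal sketch modeled on the element-wise addition example from the Pallas quickstart. It assumes a recent JAX release where `jax.experimental.pallas` provides `pallas_call`. The names `add_kernel` and `add` are illustrative, not part of any API. `interpret=True` runs the kernel through the Pallas interpreter so it also works on CPU; on a real TPU you would drop that flag, and block shapes typically need to respect TPU tiling (the `(8, 128)` shape below is chosen with that in mind).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at on-chip buffers; read both inputs and write the sum.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add(x, y):
    # pallas_call turns the kernel into a callable JAX op.
    # interpret=True is an assumption for CPU testing; remove it on a TPU.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x, y)


x = jnp.ones((8, 128), dtype=jnp.float32)
print(add(x, x)[0, :4])
```

Because the kernel is an ordinary JAX op, it can be composed with `jax.jit` and the rest of JAX, and real TPU kernels add a `grid` and `BlockSpec`s to tile larger arrays.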
I'm going to close this because I think pointing to the Pallas docs is the best we can offer.
I'm looking to build an automatic differentiation library for TPUs without using high-level front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding information about lower-level TPU usage is practically non-existent.
Specifically, I'm interested in:
Are there any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated.
I understand that some of this information might be proprietary, but any guidance on what is possible or available would be very helpful.
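For reference, the core of the automatic differentiation library described above is independent of the hardware target. The sketch below is a minimal scalar reverse-mode AD in pure Python. The `Var` class and `backward` function are hypothetical names used for illustration only. Each operation records its inputs and local derivatives, and `backward` propagates gradients in reverse topological order. In a TPU-targeted library, the recorded graph is what you would eventually lower to a backend (for example, by emitting code through JAX/Pallas as suggested in the reply); that lowering step is not shown here.

```python
class Var:
    """A scalar node in the computation graph (illustrative, not a real API)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Tuple of (parent Var, local derivative d(self)/d(parent)).
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __radd__ = __add__
    __rmul__ = __mul__


def backward(out):
    # Build a topological order of the graph with a depth-first search.
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)

    visit(out)
    # Seed d(out)/d(out) = 1 and apply the chain rule from outputs to inputs.
    out.grad = 1.0
    for v in reversed(order):
        for parent, local in v.parents:
            parent.grad += local * v.grad


x, y = Var(3.0), Var(4.0)
f = x * y + x  # f = xy + x, so df/dx = y + 1 and df/dy = x
backward(f)
print(f.value, x.grad, y.grad)
```

A production library would operate on tensors rather than scalars, but the same record-then-reverse structure applies.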