openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

How to build for TPU? #11599

Open janpfeifer opened 7 months ago

janpfeifer commented 7 months ago

I'm the developer of an ML framework (GoMLX, for the Go language) that uses XLA (XlaBuilder for now, though at some point I'll switch to PJRT), and I wanted to run some of my trainers on Cloud TPUs.

How do I build XLA for TPUs? `configure.py` doesn't seem to have a TPU option (something like `--backend=TPU`). Or is there a pre-built distribution of XLA libraries for TPUs I could use?

Thanks!

cheshire commented 7 months ago

Hey,

the TPU backend is not open source (non-OSS), so you can't build it from the GitHub repo.

janpfeifer commented 7 months ago

Oh, I'm sad to hear that... I used to see TPU-related code in the repo.

But if it's non-OSS, how can I use Cloud TPUs with my project? Is there a pre-compiled binary/.so that I can link against?

cheshire commented 7 months ago

Yeah, there's a libtpu.so you get that gives you access: https://cloud.google.com/tpu/docs/runtimes

janpfeifer commented 7 months ago

> Yeah there's a libtpu.so you get which gives you access: https://cloud.google.com/tpu/docs/runtimes

Thanks. But what is inside this libtpu.so? Would I have to extract the exported symbols, reverse-engineer some .h files, and somehow connect the OSS XLA to it? The page https://cloud.google.com/tpu/docs/runtimes doesn't provide any details.

cheshire commented 7 months ago

I think people don't access it at that level: they would use the XLA (or JAX/PyTorch) interfaces, and the implementation details would be hidden behind the PjRT/StableHLO APIs.

janpfeifer commented 7 months ago

(I updated the issue description to clarify that my question comes from someone using XLA to build an ML framework in a language other than Python.)

cheshire commented 6 months ago

Maybe the question should then be about non-Python, non-C++ bindings for StableHLO/PjRT/xla_builder?

We used to have quite a few other frontends (Julia, Elixir, etc.), but those don't look very active these days.

janpfeifer commented 6 months ago

Well, I'm using xla_builder, but I'd happily switch to StableHLO/PjRT if that would let me use TPUs.

I modeled the very first version of the XLA bindings for Go on the Elixir version.

notlober commented 3 months ago

@janpfeifer I'm looking to build an automatic differentiation library without using front-ends like TensorFlow/JAX/PyTorch-XLA, but I'm finding information about lower-level TPU usage is practically non-existent. I see you were interested in creating an ML framework too. Do you have any insights or suggestions on how to approach this, particularly regarding TPU support? Any ideas or help would be greatly appreciated.

janpfeifer commented 3 months ago

Hi @notlober,

A couple of high-level aspects to consider:

  1. My understanding is that XLA's JIT compilation and optimization currently sit behind a better-defined C API called PJRT. It uses a shared-library (.so or DLL) model: the so-called PJRT plugin.

This allows makers of proprietary accelerator hardware to distribute binary-only plugins.

TPU is one such proprietary plugin: you will probably find it on the Google Cloud boxes with TPU support. And presumably (I haven't tried it yet) you can develop against a standard CPU or GPU PJRT plugin, and it will work all the same on TPU.

I recently separated my ML framework (GoMLX) implementation from a Go wrapper around PJRT (now a separate GitHub project called gopjrt), and I highly recommend separating these things: PJRT gives you JIT compilation and fast execution of ML ops, and nothing else. You can then build your thing on top of it.

If you are in C++ land, you can use the PJRT wrappers already offered in the github.com/openxla/xla repository, but they are hard to compile and link.

  2. PJRT consumes an "HLO", an intermediate proto representation of the accelerated program you are going to execute. You still need to figure out how to translate your language/library into that.

In C++, openxla/xla offers a library called XlaBuilder: it is hard to compile and link, but it works very well. I'm wrapping it in the gopjrt project.

There is also "StableHLO" (and many variations), a textual language you can consider: it is not human-friendly to read or write, but it is good enough for debugging or as a target to generate. But then you'll have to figure out how to convert it to the required HLO proto; the tools for that are somewhere in the openxla/xla repository, but the same compile/link concerns apply.
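To make "generate StableHLO" concrete, here is a minimal sketch in Python for brevity: a hypothetical toy builder that records same-shaped elementwise ops and prints them as a StableHLO module with a single `@main` function. `StableHLOBuilder`, `tensor_type`, and `BINARY_OPS` are illustrative names, not part of any XLA API; the builder style is only loosely reminiscent of XlaBuilder. Only the listed ops are covered, and the later steps mentioned above (converting to the HLO proto and compiling with PJRT) are not shown.

```python
# Elementwise binary ops from the StableHLO opset that this sketch supports.
BINARY_OPS = {"add", "subtract", "multiply", "divide", "maximum", "minimum"}


def tensor_type(shape: tuple[int, ...], dtype: str = "f32") -> str:
    """Formats an MLIR ranked tensor type, e.g. (2, 3) -> 'tensor<2x3xf32>'."""
    if not shape:
        return f"tensor<{dtype}>"  # rank-0 (scalar) tensor
    return "tensor<" + "x".join(str(d) for d in shape) + f"x{dtype}>"


class StableHLOBuilder:
    """Hypothetical toy builder: all values share one tensor type."""

    def __init__(self, shape: tuple[int, ...], dtype: str = "f32") -> None:
        self.type = tensor_type(shape, dtype)
        self.num_params = 0
        self.body: list[str] = []

    def parameter(self) -> str:
        # Function arguments are named %arg0, %arg1, ... in declaration order.
        name = f"%arg{self.num_params}"
        self.num_params += 1
        return name

    def binary(self, op: str, lhs: str, rhs: str) -> str:
        if op not in BINARY_OPS:
            raise ValueError(f"unsupported op: {op}")
        name = f"%{len(self.body)}"  # SSA values %0, %1, ... in program order
        self.body.append(f"    {name} = stablehlo.{op} {lhs}, {rhs} : {self.type}")
        return name

    def build(self, result: str) -> str:
        params = ", ".join(f"%arg{i}: {self.type}" for i in range(self.num_params))
        return "\n".join([
            "module @generated {",
            f"  func.func @main({params}) -> {self.type} {{",
            *self.body,
            f"    return {result} : {self.type}",
            "  }",
            "}",
        ])


if __name__ == "__main__":
    # f(x, y) = (x + y) * x on 2x3 float32 tensors.
    b = StableHLOBuilder((2, 3))
    x, y = b.parameter(), b.parameter()
    print(b.build(b.binary("multiply", b.binary("add", x, y), x)))
```

Keeping this as plain text generation, with no dependency on XLA, mirrors the separation recommended above: the framework emits the program, and the PJRT wrapper only compiles and runs it.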

I hope this helps. If you can read Go (or just want to look at the Bazel WORKSPACE and BUILD files), you may want to bootstrap your code from my repositories and adapt it from there.

It's a lot of work, but XLA runs really well and has been very stable. I've been training many models with my framework without any issues (mostly on CPU and GPU; I haven't tried TPU yet. It's on my TODO list, but I've been postponing it because of the $$$ of renting the cloud boxes).