openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[Epic] Migrate XLA:CPU custom calls to use typed FFI #10059

Open penpornk opened 6 months ago

penpornk commented 6 months ago

Previous XLA custom call API versions pass parameters as void** buffers. The new version, typed FFI, allows passing metadata such as data type and shape along with each parameter, which greatly improves the overall programmability and flexibility.
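To make the difference concrete, here is a minimal, self-contained sketch of the two calling styles. It does not use the real XLA headers: `DataType`, `TypedBuffer`, `AddOneLegacy`, and `AddOneTyped` are hypothetical stand-ins invented for illustration, and the explicit `count` parameter on the legacy function is an assumption that represents size information the callee has to obtain out of band.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for illustration only; these are NOT the XLA FFI types.
enum class DataType { F32, S32 };

struct TypedBuffer {
  DataType dtype;             // element type travels with the buffer
  std::vector<int64_t> dims;  // shape travels with the buffer
  void* data;                 // raw storage, interpreted according to dtype
};

// Legacy style: opaque pointers only. The callee must already know the element
// type and the element count from somewhere else (here: the `count` argument).
void AddOneLegacy(void* out, void** ins, int64_t count) {
  assert(out != nullptr && ins != nullptr);
  const float* in = static_cast<const float*>(ins[0]);
  float* result = static_cast<float*>(out);
  for (int64_t i = 0; i < count; ++i) result[i] = in[i] + 1.0f;
}

// Number of elements described by a buffer's shape.
int64_t NumElements(const TypedBuffer& buffer) {
  int64_t n = 1;
  for (int64_t d : buffer.dims) n *= d;
  return n;
}

// Typed style: the callee can validate type and shape before touching memory,
// and reports a mismatch instead of silently misreading the data.
bool AddOneTyped(const TypedBuffer& in, TypedBuffer& out) {
  if (in.dtype != DataType::F32 || out.dtype != DataType::F32) return false;
  if (in.dims != out.dims) return false;
  const float* src = static_cast<const float*>(in.data);
  float* dst = static_cast<float*>(out.data);
  const int64_t n = NumElements(in);
  for (int64_t i = 0; i < n; ++i) dst[i] = src[i] + 1.0f;
  return true;
}
```

In the typed variant, passing an `S32` buffer is rejected up front, whereas the legacy variant would reinterpret the bytes as `float` without any warning.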

Custom call API comparison between versions: https://github.com/openxla/xla/blob/76e1d730d4546b9a15c187102e3ee37734ff8ffa/xla/service/hlo.proto#L51-L111

XLA:GPU already has typed FFI support (part of the Thunks runtime work). Per our goal to share runtime components between XLA:GPU and XLA:CPU, we will be migrating XLA:CPU to use typed FFI for custom calls as well.

See more custom call examples (v1-v4) in issue https://github.com/openxla/xla/issues/8319.

### Tasks
- [ ] https://github.com/openxla/xla/issues/10056
- [ ] https://github.com/openxla/xla/issues/10060
- [ ] https://github.com/openxla/xla/issues/10062

Zantares commented 3 months ago

Hello @penpornk, Intel GPU is also adopting this new FFI design to extend custom calls, but we found some pain points in the current design after finishing a simple PoC:

  1. Currently there are only in-tree examples, and the implementation becomes more complex when the custom call is out-of-tree. Please see the comment here: https://github.com/openxla/xla/issues/8319#issuecomment-2109755854. Would it be possible to provide some out-of-tree examples, such as how to reimplement custom calls as FFI calls in jaxlib?
  2. The custom call instruction has some inherent members, such as window (https://github.com/openxla/xla/blob/7491cd61955271984d422b1de4338e8a1b70c09a/xla/client/xla_builder.cc#L2667-L2669), that are widely used in real custom callees. Could Google add parsing of these members to the thunk process?
  3. According to the FFI design, backend_config_str is parsed into an attribute map (https://github.com/openxla/xla/blob/91a45ef9faf67d440d4d4c8e9dec5fc8c0f5929a/xla/service/gpu/ir_emitter_unnested.cc#L1407-L1418) in the thunk process, but in our PoC this failed because some escape characters in backend_config_str were not recognized. We will check whether the latest main branch works.

Zantares commented 3 months ago

A follow-up to item 1 in my previous comment: it seems we need to wrap the stream and other required handles as int64_t values if the internal class ServiceExecutableRunOptions is no longer preferred.
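A minimal sketch of that idea, assuming the handle is simply a pointer that both sides agree to pass as a 64-bit integer attribute. `FakeStream`, `EncodeHandle`, `DecodeHandle`, and `LaunchOnStream` are hypothetical names used only for illustration; they are not part of XLA or of any Intel GPU runtime.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-in for a backend-specific stream or queue object.
struct FakeStream {
  std::string name;
  int launches = 0;
};

// The round trip below only works if a pointer fits in 64 bits.
static_assert(sizeof(void*) <= sizeof(int64_t),
              "pointer does not fit in an int64_t attribute");

// Caller side: encode the pointer as a plain integer that can be carried
// as an ordinary attribute value.
int64_t EncodeHandle(FakeStream* stream) {
  assert(stream != nullptr);  // encoding a null stream is a caller bug
  return static_cast<int64_t>(reinterpret_cast<std::intptr_t>(stream));
}

// Callee side: recover the pointer from the integer attribute.
FakeStream* DecodeHandle(int64_t value) {
  return reinterpret_cast<FakeStream*>(static_cast<std::intptr_t>(value));
}

// Hypothetical handler that receives the stream as an integer attribute
// instead of reaching into an internal run-options class.
bool LaunchOnStream(int64_t stream_attr) {
  FakeStream* stream = DecodeHandle(stream_attr);
  if (stream == nullptr) return false;  // reject a missing handle
  ++stream->launches;                   // stand-in for enqueueing real work
  return true;
}
```

An integer-encoded pointer is only meaningful inside the process that created it and only while the underlying object is alive, so it should not be cached or serialized together with the compiled program.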

pparuzel commented 3 months ago

About "[...] how to reimplement custom calls as FFI calls in jaxlib?", there is ongoing work currently on the CPU backend. Maybe you will find it useful: google/jax#21574

Zantares commented 2 months ago

> About "[...] how to reimplement custom calls as FFI calls in jaxlib?", there is ongoing work currently on the CPU backend. Maybe you will find it useful: google/jax#21574

Thanks, we are looking into it.

Zantares commented 1 month ago

Hi @penpornk, we prepared a commit that reproduces item 3 from https://github.com/openxla/xla/issues/10059#issuecomment-2147712924. Could you help check it? We can file a public issue if needed. Thanks.