openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[xla:cpu] Pre-construct call frames for typed-FFI custom calls in the thunk runtime. #14314

copybara-service[bot] closed this 1 week ago

copybara-service[bot] commented 1 week ago

Pre-construct the call frame at compile time, when creating the CustomCall thunk, instead of constructing it on the fly while running the thunk. Since only device memory addresses change at runtime, we can pre-populate a prototype call frame with attributes and other buffer information, then update the addresses when the thunk is executed.
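The idea can be illustrated with a minimal, self-contained sketch. Everything in it (`BufferArg`, `CallFrame`, `BuildPrototype`, `UpdateAddresses`) is a hypothetical stand-in invented for illustration, not XLA's actual FFI call frame API; it only shows the split between work done once at thunk creation and work done on every execution.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for illustration only; NOT XLA's FFI types.
struct BufferArg {
  void* data;                 // device address: the only field that changes per run
  std::vector<int64_t> dims;  // shape: known when the thunk is created
};

struct CallFrame {
  std::vector<BufferArg> args;
  std::vector<std::pair<std::string, int64_t>> attrs;  // known at thunk creation
};

// Done once, when the thunk is created: encode attributes and buffer shapes,
// leaving every address as a null placeholder.
CallFrame BuildPrototype(const std::vector<std::vector<int64_t>>& shapes,
                         std::vector<std::pair<std::string, int64_t>> attrs) {
  CallFrame frame;
  frame.attrs = std::move(attrs);
  for (const auto& dims : shapes) {
    frame.args.push_back({nullptr, dims});
  }
  return frame;
}

// Done on every execution: patch in the current addresses. The cost depends
// only on the number of buffers, not on the number of attributes.
void UpdateAddresses(CallFrame& frame, const std::vector<void*>& addrs) {
  for (std::size_t i = 0; i < frame.args.size() && i < addrs.size(); ++i) {
    frame.args[i].data = addrs[i];
  }
}
```

In this sketch the frame is updated in place for brevity. A runtime that can execute the same thunk concurrently would more likely copy the prototype per execution and patch the copy, which is an assumption about the design rather than something stated in this commit message.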

Also add benchmarks to measure the difference. To exercise the thunk runtime, they must be run with XLA_FLAGS=--xla_cpu_use_thunk_runtime=true.

Thunk runtime (this commit, "new") vs thunk runtime (previous commit, "old"):

name                                         old cpu/op   new cpu/op  delta
BM_CustomCall_Minimal/process_time           944ns ± 3%   853ns ± 0%   -9.64%  (p=0.016 n=5+4)
BM_CustomCall_16IntAttributes/process_time  29.9µs ± 2%   1.0µs ± 5%  -96.58%  (p=0.008 n=5+5)
BM_CustomCall_16FloatBuffers/process_time   3.23µs ± 1%  2.79µs ± 2%  -13.44%  (p=0.008 n=5+5)

Thunk runtime (this commit, "new") vs classic runtime ("old"):

name                                         old cpu/op   new cpu/op  delta
BM_CustomCall_Minimal/process_time           852ns ± 4%   873ns ± 3%     ~     (p=0.151 n=5+5)
BM_CustomCall_16IntAttributes/process_time  30.1µs ± 2%   1.0µs ± 0%  -96.67%  (p=0.016 n=5+4)
BM_CustomCall_16FloatBuffers/process_time   3.27µs ± 5%  2.81µs ± 3%  -13.94%  (p=0.008 n=5+5)