openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0
2.39k stars, 356 forks

[PJRT:GPU] Implement copying buffers to pinned host memory space #14268

Closed: jaro-sevcik closed this 2 days ago

jaro-sevcik commented 3 days ago

@jyingl3 Could you take a look, please?

This is needed to make jax.device_put(..., ...Sharding(..., memory_kind="pinned_host")) work on GPU.
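For illustration, here is a minimal sketch of the kind of user-facing call the PR description refers to. It assumes a recent JAX release in which shardings accept a `memory_kind` argument and devices expose `addressable_memories()`. `SingleDeviceSharding` stands in for the elided `...Sharding` in the description, and `put_on_pinned_host` is a hypothetical helper name, not part of JAX or of this PR.

```python
import jax
import numpy as np
from jax.sharding import SingleDeviceSharding


def put_on_pinned_host(x):
    """Copy `x` into the pinned host memory space of the first device.

    Hypothetical helper for illustration. Returns None when the device
    does not report a "pinned_host" memory kind (for example, on a
    backend or jaxlib version that lacks it).
    """
    dev = jax.devices()[0]
    # Memory kinds this device can address, e.g. "device" or "pinned_host".
    kinds = {m.kind for m in dev.addressable_memories()}
    if "pinned_host" not in kinds:
        return None
    # Same device as a plain placement, but request pinned host memory.
    sharding = SingleDeviceSharding(dev, memory_kind="pinned_host")
    return jax.device_put(x, sharding)


x = np.arange(8, dtype=np.float32)
y = put_on_pinned_host(x)
if y is not None:
    # The resulting array's sharding should record the requested memory kind.
    print(y.sharding.memory_kind)
```

The availability check keeps the sketch usable on backends that do not expose a pinned host memory; on GPU, the `jax.device_put` call is the path this change targets.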