pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Add functions to emit custom call to place a buffer to host and device. #8350

Closed qihqi closed 3 weeks ago

qihqi commented 3 weeks ago

These functions are used for host offloading.
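The PR text above does not spell out the Python-level API, so the following is only a minimal usage sketch under assumptions: the helper names place_to_host / place_to_device and their import path are illustrative, not confirmed by this PR. The idea is to annotate a tensor so that the lowered StableHLO carries an annotate_device_placement custom call with the _xla_buffer_placement frontend attribute:

import torch
import torch_xla.core.xla_model as xm
# Hypothetical import path and helper names for the functions added by this PR.
from torch_xla.experimental.stablehlo_custom_call import (
    place_to_host, place_to_device)

device = xm.xla_device()
x = torch.randn(16, device=device)
y = torch.sin(x)
# Annotate y to be placed in pinned host memory
# (custom_call @annotate_device_placement, _xla_buffer_placement = "pinned_host").
y_host = place_to_host(y)
# Annotate it to be moved back to the device
# (_xla_buffer_placement = "device") before it is consumed again.
y_back = place_to_device(y_host)
out = torch.sin(y_back).sum()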

As an example of the lowering this targets, here is what JAX emits for the same pattern:

import functools
import jax
import jax.numpy as jnp
# Assumption: Offloadable comes from JAX's remat internals; the exact
# import path may differ across JAX versions.
from jax._src.ad_checkpoint import Offloadable

def policy(prim, *avals, **params) -> Offloadable:
  # Offload every saved residual from the device to pinned host memory.
  return Offloadable(src='device', dst='pinned_host')

@functools.partial(jax.remat, policy=policy)
def f(x):
  x = jnp.sin(x)
  x = jnp.sin(x)
  return jnp.sum(x)

becomes:

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<16xf32>
    %1 = stablehlo.cosine %arg0 : tensor<16xf32>
    %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %3 = stablehlo.cosine %0 : tensor<16xf32>
    %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32>
    %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %9 = stablehlo.multiply %8, %7 : tensor<16xf32>
    %10 = stablehlo.multiply %9, %6 : tensor<16xf32>
    return %10 : tensor<16xf32>
  }
}
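
For reference, the module above looks like the lowered gradient of f (the cosines are the derivatives of the saved sine residuals). A minimal sketch of how one might print this StableHLO from JAX, reusing the definitions above; the exact text can vary by JAX version:

import jax
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32)
# The annotate_device_placement custom calls wrap the offloaded residuals,
# which only appear once the backward pass is built, so lower the gradient
# of f (defined above) rather than f itself.
print(jax.jit(jax.grad(f)).lower(x).as_text())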
tengyifei commented 3 weeks ago

Need to format the Python files to pass the linter.