Closed by qihqi 3 weeks ago.
This is used for host offloading.
Example of what JAX emits. The following code:
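```python
import functools

import jax
import jax.numpy as jnp

# `Offloadable` comes from JAX's remat internals; its import path may vary by JAX version.

def policy(prim, *avals, **params) -> Offloadable:
    # Offload every residual from device memory to pinned host memory.
    return Offloadable(src='device', dst='pinned_host')

@functools.partial(jax.remat, policy=policy)
def f(x):
    x = jnp.sin(x)
    x = jnp.sin(x)
    return jnp.sum(x)
```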
becomes the following StableHLO:
```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<16xf32>
    %1 = stablehlo.cosine %arg0 : tensor<16xf32>
    %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %3 = stablehlo.cosine %0 : tensor<16xf32>
    %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32>
    %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %9 = stablehlo.multiply %8, %7 : tensor<16xf32>
    %10 = stablehlo.multiply %9, %6 : tensor<16xf32>
    return %10 : tensor<16xf32>
  }
}
```
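In the module above, the residuals `%1 = cos(x)` and `%3 = cos(sin(x))` are annotated `pinned_host` before the `optimization_barrier`. After the barrier they are annotated `device` again, and then multiplied into `%10 = cos(sin(x)) * cos(x)`, which is the gradient of `f`. To see this round trip at a glance, you can list every `annotate_device_placement` call together with its `_xla_buffer_placement` attribute. The following is a minimal sketch that assumes the textual form shown above. The helper `find_placements` and its regex are hypothetical and are not part of JAX or XLA.

```python
import re
from typing import List, Tuple

# Hypothetical helper, not a JAX/XLA API. It matches ops such as:
#   %2 = stablehlo.custom_call @annotate_device_placement(%1) {... _xla_buffer_placement = "pinned_host"} ...
# The lazy `.*?` stops at the first placement attribute after the operand,
# which is the one belonging to the same op.
_PLACEMENT_RE = re.compile(
    r'(%[\w#]+) = stablehlo\.custom_call @annotate_device_placement\((%[\w#]+)\)'
    r'.*?_xla_buffer_placement = "(\w+)"'
)


def find_placements(stablehlo_text: str) -> List[Tuple[str, str, str]]:
    """Return (result, operand, placement) for each annotate_device_placement call."""
    return _PLACEMENT_RE.findall(stablehlo_text)


# A few ops copied from the emitted module above.
SAMPLE = '''
%2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
%4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
%6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
'''

if __name__ == "__main__":
    # Print each annotated value and the memory space it is moved to.
    for result, operand, placement in find_placements(SAMPLE):
        print(f"{operand} -> {result}: {placement}")
```

A regex is enough for a quick inspection like this one. For anything sturdier, walk the parsed module with the MLIR Python bindings rather than matching text.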
Need to format Python files to pass the linter