openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

[NVIDIA GPU] Add debug flag for syntactic sugar #14865

Closed · terryysun closed this 1 month ago

terryysun commented 1 month ago

This is a follow-up to https://github.com/openxla/xla/pull/14344. Originally the issue was that HLO dumps and NVTX marker names were inconsistent; with https://github.com/openxla/xla/pull/14344, both are now wrapped in syntactic sugar. In some cases, especially when debugging, the original naming without syntactic sugar is helpful. This PR adds a debug flag that controls the syntactic sugar for both HLO dumps and NVTX markers.
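
For reference, a minimal sketch of how the flag can be toggled from JAX. The dump directory, the stand-in pmap'd lambda, and the 8-device setup are illustrative assumptions, not taken from this PR; only the flag names themselves come from the examples below.

import os

# XLA debug flags are passed via the XLA_FLAGS environment variable and must
# be set before the XLA backend is initialized (so before the first JAX call).
# --xla_dump_to writes HLO modules to a directory for inspection;
# --xla_syntax_sugar_async_ops toggles the sugar added by this PR.
os.environ["XLA_FLAGS"] = (
    "--xla_syntax_sugar_async_ops=false "  # dump the async-start/async-done form
    "--xla_dump_to=/tmp/hlo_dumps"         # arbitrary output directory
)

import jax
import jax.numpy as jnp
from jax import lax

# Stand-in for the pmap'd lambda behind the dumps below: per-device f32[4,8]
# in, f32[8,4] out, with a cross-replica all-to-all (assumes 8 local devices,
# matching replica_count=8 in the dumped module).
f = jax.pmap(lambda x: lax.all_to_all(x, "i", split_axis=1, concat_axis=0),
             axis_name="i")

x = jnp.arange(8 * 4 * 8, dtype=jnp.float32).reshape(8, 4, 8)
y = f(x)  # HLO dumps, with or without the sugar, land in /tmp/hlo_dumps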

terryysun commented 1 month ago

Posting HLO examples showing the effect of the added flag:

with --xla_syntax_sugar_async_ops=true

HloModule pmap__lambda_, is_scheduled=true, entry_computation_layout={(f32[4,8]{1,0})->f32[8,4]{1,0}}, replica_count=8, frontend_attributes={fingerprint_before_lhs="47ebad045c9d61518cb31be75fe36648"}

wrapped_transpose_computation {
  param_0.1 = f32[4,8]{1,0} parameter(0)
  ROOT transpose.2.1 = f32[8,4]{1,0} transpose(param_0.1), dimensions={1,0}
}

ENTRY main.9 {
  Arg_0.1.0 = f32[4,8]{1,0} parameter(0)
  wrapped_transpose = f32[8,4]{1,0} fusion(Arg_0.1.0), kind=kLoop, calls=wrapped_transpose_computation
  bitcast.24.0 = f32[1,4,8]{1,0,2} bitcast(wrapped_transpose)
  all-to-all-start = ((f32[1,4,8]{1,0,2}), f32[1,4,8]{1,0,2}) all-to-all-start(bitcast.24.0), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false},"force_earliest_schedule":false}
  all-to-all-done = f32[1,4,8]{1,0,2} all-to-all-done(all-to-all-start)
  ROOT bitcast.2.0 = f32[8,4]{1,0} bitcast(all-to-all-done)
} // main.9

with --xla_syntax_sugar_async_ops=false (default)

HloModule pmap__lambda_, is_scheduled=true, entry_computation_layout={(f32[4,8]{1,0})->f32[8,4]{1,0}}, replica_count=8, frontend_attributes={fingerprint_before_lhs="47ebad045c9d61518cb31be75fe36648"}

async_computation {
  param_0 = f32[1,4,8]{1,0,2} parameter(0)
  ROOT all-to-all.3.1 = f32[1,4,8]{1,0,2} all-to-all(param_0), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}
}

wrapped_transpose_computation {
  param_0.1 = f32[4,8]{1,0} parameter(0)
  ROOT transpose.2.1 = f32[8,4]{1,0} transpose(param_0.1), dimensions={1,0}
}

ENTRY main.9 {
  Arg_0.1.0 = f32[4,8]{1,0} parameter(0)
  wrapped_transpose = f32[8,4]{1,0} fusion(Arg_0.1.0), kind=kLoop, calls=wrapped_transpose_computation
  bitcast.24.0 = f32[1,4,8]{1,0,2} bitcast(wrapped_transpose)
  all-to-all-start = ((f32[1,4,8]{1,0,2}), f32[1,4,8]{1,0,2}) async-start(bitcast.24.0), calls=async_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false},"force_earliest_schedule":false}
  all-to-all-done = f32[1,4,8]{1,0,2} async-done(all-to-all-start)
  ROOT bitcast.2.0 = f32[8,4]{1,0} bitcast(all-to-all-done)
} // main.9