iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Add user/framework-configurable named workload section markers #16051

Open ScottTodd opened 9 months ago

ScottTodd commented 9 months ago

I don't think we have something similar to this already, and I don't recall many similar issues/discussions (perhaps https://github.com/openxla/iree/issues/13145).

A common profiling request I see is "I want a breakdown of time spent in each model component", where the user has a fairly high-level view of what a "component" is. For example, https://bbycroft.net/llm has "Embedding", "Self Attention", "Transformer", "Softmax", and other components for nano-gpt. Some "components" could be inferred from top-level functions or split into separate .vmfb files, but frameworks and importers probably don't use quite the level of granularity that a user would expect.

Vulkan has this extension: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_EXT_debug_marker.html

"The VK_EXT_debug_marker extension is a device extension. It introduces concepts of object naming and tagging, for better tracking of Vulkan objects, as well as additional commands for recording annotations of named sections of a workload to aid organization and offline analysis in external tools."

I think a similar concept could be added to IREE as a way to pass down some annotations about conceptual parts of an input program. These annotations would be carried through to the HAL dialect, ideally in a way that preserves the user-provided signals without interfering with the compiler's ability to fuse, de-dup, link, etc. Tools like Tracy could then include statistics for the given sections (min/max/average memory, start time, end time, processor utilization, etc.).
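To make the idea concrete, here is a minimal sketch of begin/end style section markers on the Python side. `SectionRecorder` and its methods are hypothetical names invented for illustration; this is not an IREE, HAL, or Tracy API, and it only records wall-clock time rather than memory or utilization.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class SectionRecorder:
    """Hypothetical helper: records wall-clock durations per named section."""

    def __init__(self, clock=time.perf_counter):
        # The clock is injectable so the recorder can be driven deterministically.
        self._clock = clock
        self._durations = defaultdict(list)

    @contextmanager
    def section(self, name):
        # Entering the block acts as the "begin" marker, leaving it as "end".
        start = self._clock()
        try:
            yield
        finally:
            self._durations[name].append(self._clock() - start)

    def summary(self):
        # Per-section count, min, max, and mean duration in seconds.
        return {
            name: {
                "count": len(values),
                "min": min(values),
                "max": max(values),
                "mean": sum(values) / len(values),
            }
            for name, values in self._durations.items()
        }
```

A caller would wrap each conceptual part of the program in `with recorder.section("Embedding"): ...` and read `recorder.summary()` afterwards. A compiler-level version would have to carry the same names through lowering, which is the hard part this sketch does not address.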

Another idea would be to build some tooling/scripts/Python APIs/etc. to help with slicing programs into the sorts of pieces that could then be profiled independently. In the nano-gpt case that would mean running independent benchmarks for "Embedding", "Self Attention", etc. Would that still yield useful data to people analyzing/improving performance for those sorts of programs?

Thoughts?

ScottTodd commented 9 months ago

Had a few more ideas. We already have robust infrastructure for source location tracking in MLIR, so we could continue building off of that instead of adding a new annotation. There is a bit of a tradeoff between support in the existing infrastructure and generality/ambiguity though.

Here's an example source location from Python (from https://github.com/openxla/iree/pull/13500):

loc(fused[callsite("jit(run_cnn)/jit(main)/CNN/jit(relu)/max"("/usr/local/google/home/scotttodd/code/scratch/iree/jax/jax_cnn.py":9:0) at "jit(run_cnn)/jit(main)/CNN/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=relu keep_unused=False inline=False]"("/usr/local/google/home/scotttodd/code/scratch/iree/jax/jax_cnn.py":9:0)), "jit(run_cnn)/jit(main)/CNN/Conv_0/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]"("/usr/local/google/home/scotttodd/code/scratch/iree/jax/.venv/lib/python3.10/site-packages/flax/linen/linear.py":449:0)])

I'd need to look further at the Python / MLIR sources for current programs of interest, but I suspect they aren't as neatly organized as:

def component_embedding():
  layer_1()
  op_2()

def component_self_attention():
  layer_1()
  layer_2()

def program():
  component_embedding()
  component_self_attention()

If they are, we could do a bit of work to extract function, class, or file names from source locations and use those as executable sources that are already plumbed through to Tracy (see also https://github.com/openxla/iree/issues/15699). If we want to be a bit more flexible, we could have a flag that lets you tell the compiler which style of source location to pin. I still feel like explicit markers in Python (or scripted/manual edits of the .mlir files) would be easier to understand, especially if implicit source location plumbing is unreliable.
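As a rough sketch of the extraction step, the function below pulls a component name out of a JAX-style name scope string like the ones in the source location above. The function name, the `depth` parameter, and the rule of dropping `jit(...)`/`pjit(...)` wrapper segments are assumptions made for illustration; real name scopes may need different handling.

```python
import re

# Wrapper segments such as "jit(main)" or "pjit(f)" that carry no component name.
_WRAPPER = re.compile(r"p?jit\(.*\)")


def component_from_name_scope(scope, depth=1):
    """Return the first `depth` component segments of a name scope, or None."""
    # Drop the bracketed parameter list, e.g. "[window_strides=(1, 1) ...]".
    path = scope.split("[", 1)[0]
    segments = path.split("/")
    # The last segment is the op itself ("max", "conv_general_dilated"), so skip it.
    components = [s for s in segments[:-1] if s and not _WRAPPER.fullmatch(s)]
    if not components:
        return None
    return "/".join(components[:depth])
```

With the example above, `"jit(run_cnn)/jit(main)/CNN/jit(relu)/max"` maps to `"CNN"`, and the `conv_general_dilated` scope maps to `"CNN/Conv_0"` when `depth=2`. A `depth`-style knob is one way the "which style of source location to pin" flag could be expressed.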

ScottTodd commented 9 months ago

If we go the source location route, some extra tooling to collect statistics for dispatches grouped by their source locations would be useful. Maybe another tracing sink, or a script that runs on .tracy files, etc.
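A minimal sketch of the grouping step, assuming the dispatch timings have already been exported as `(section_name, duration_ms)` pairs. It does not read .tracy files, and `summarize_by_section` is a hypothetical name rather than an existing tool.

```python
from collections import defaultdict
from statistics import mean


def summarize_by_section(records):
    """Aggregate (section_name, duration_ms) pairs into per-section statistics."""
    grouped = defaultdict(list)
    for section, duration_ms in records:
        grouped[section].append(duration_ms)
    return {
        section: {
            "dispatches": len(durations),
            "total_ms": sum(durations),
            "min_ms": min(durations),
            "max_ms": max(durations),
            "mean_ms": mean(durations),
        }
        for section, durations in grouped.items()
    }
```

The section name for each record could come from an explicit marker or from a name derived from the dispatch's source location.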

Another thing to consider is whether dispatches actually include all relevant data (timing, memory, etc.), or whether these sorts of analyses would also care about other parts of the program (pipeline initialization, data upload/download, etc.). I don't think we currently associate such operations with any MLIR source locations or have a lightweight / CLI-friendly way to extract that data (besides running Tracy and getting all the data).

benvanik commented 9 months ago

let's chat when I'm back - this stuff is tricky to make actually useful (especially when involving source locations - those make things 99x less likely to be useful to actual users, just to compiler devs)

ScottTodd commented 9 months ago

I'm still collecting requirements/samples and brainstorming here. For "user/framework-configurable markers" I think leaning on Python decorators like shark_turbine.aot.jittable could make sense. Either

A) Verifying that jittable already does what we want somehow
B) Adding new behavior to jittable
C) Adding a new, compatible decorator (or an option on that decorator somehow)

Example usage:

Similar plumbing through JAX should be possible, at least with iree-jax. PJRT ... I'm not sure. When I last looked at how PJRT operated, it JIT'd all sorts of slices of the program in ways that could be difficult to reason about and map between the user-authored Python program and what the compiler ultimately produces.