iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

codegen producing register count overflow for hip target gfx942 with 8 devices #18923

Open KyleHerndon opened 4 weeks ago

KyleHerndon commented 4 weeks ago

What happened?

I'm getting a register count overflow when trying to run llama3.1_405b_fp16 for 8 HIP devices targeting gfx942:

    iree/runtime/src/iree/vm/bytecode/verifier.c:345: RESOURCE_EXHAUSTED; register count overflow; loading bytecode module at '405b_f16_tp8.vmfb'; loading modules and dependencies; creating run context

Steps to reproduce your issue

Using llama3.1_405b_f16_tp8.mlir obtained from the SHARK-Platform model export process, compile the MLIR:

iree-build/tools/iree-compile /data/llama-3.1/405b/llama3.1_405b_fp16_tp8.mlir   --iree-hal-target-device=hip[0] --iree-hal-target-device=hip[1] --iree-hal-target-device=hip[2] --iree-hal-target-device=hip[3] --iree-hal-target-device=hip[4] --iree-hal-target-device=hip[5] --iree-hal-target-device=hip[6] --iree-hal-target-device=hip[7]   --iree-hal-executable-debug-level=3 --iree-hip-target=gfx942   -o 405b_f16_tp8.vmfb

Attempt to run the vmfb

iree-build/tools/iree-run-module --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank0.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank1.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank2.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank3.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank4.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank5.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank6.irpa" --parameters="model=/data/llama3.1/405b/llama3.1_405b_fp16_tp8_parameters.rank7.irpa"  --module=405b_f16_tp8.vmfb --function=prefill_bs4   --input=4x1xi64=0   --input=4xi64=1   --input=4x1xi64=0,1,2,3   --input=1x2097152xf16 --device=hip[0] --device=hip[1] --device=hip[2] --device=hip[3] --device=hip[4] --device=hip[5] --device=hip[6] --device=hip[7]

What component(s) does this issue relate to?

Compiler

Version information

3b751a4d2

Additional context

MLIR, vmfb. The sharded parameter files are not necessary to replicate this error, but I can upload them to a private location if necessary.

stellaraccident commented 3 weeks ago

@benvanik I think this was one of those "if you ever need this, we want to take a close look" kinds of things. With that said, we took a program that was already large and multiplied it by 8...

benvanik commented 3 weeks ago

That means it's hitting 65k live values. There's no going bigger than that - only making the compiler better. Best bet is to split each thing into its own function instead of inlining it all into one.

benvanik commented 3 weeks ago

(a --compile-to=vm dump would be useful in seeing how the hell there are 65k live values - it may just be a bug/pathological behavior in the crappy register allocator I begrudgingly wrote 5 years ago :)
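
For reference, a minimal sketch of how such a dump could be produced, assuming the reproduction compile command from the issue description (the loop over device flags and the output filename are illustrative, not taken from the thread):

    # Sketch only: same inputs/targets as the repro compile, but stop at the VM
    # dialect and emit textual IR instead of a .vmfb.
    DEVICE_FLAGS=()
    for i in $(seq 0 7); do DEVICE_FLAGS+=("--iree-hal-target-device=hip[$i]"); done
    iree-build/tools/iree-compile /data/llama-3.1/405b/llama3.1_405b_fp16_tp8.mlir \
      "${DEVICE_FLAGS[@]}" \
      --iree-hip-target=gfx942 \
      --compile-to=vm \
      -o 405b_f16_tp8.vm.mlir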

stellaraccident commented 3 weeks ago

Ah this is live values. That is a lot. My bet is on something that has been pathological for a while and then we just multiplied it by 8.

benvanik commented 3 weeks ago

Yeah, maybe it's got like 40k unique constants that are used multiple times so their lifetimes overlap. I've got a TODO for making constants better but still would not be great (no compiler/JIT/etc ever wants to see a function with tens of thousands of locals :). We need the compiler to spill (e.g. to a temporary vm.list acting as the stack) or rematerialize (clone constants near uses). The former sucks for performance but if the latter was the issue it'd be ok.

benvanik commented 3 weeks ago

(having one non-inlinable function per shard is going to be the best, though - you'll get way faster compiles too for things nested on function scope)

stellaraccident commented 3 weeks ago

That would be best but it's kind of the opposite of how any frontend I know of will want to split things.

benvanik commented 3 weeks ago

It can be something we do near the frontend/global opt (affinity analysis -> slice -> outline).

MLIR/linalg/etc is not built for big things -- and in some cases not even medium ones -- and if we want reasonable compile times we'll need to do something other than chew on bajillion op functions - if the frontend can simplify things that's best but otherwise we need to :(

stellaraccident commented 3 weeks ago

Yeah, that's what I meant... Need to partition on some axis

stellaraccident commented 3 weeks ago

But we need to look at this specific case. In the optimized single device cases I was producing, things weren't crazy. My vote is on pathological machinery somewhere

benvanik commented 3 weeks ago

Yeah --compile-to=vm will be the best place to start - I bet there's some low hanging fruit.

stellaraccident commented 3 weeks ago

Here is IR (as bytecode) of /home/stella/tmp/llamabig/llama3.1_405b_f16_tp8.mlir.stream : https://drive.google.com/file/d/1Y_G7helvJ15R8DVcVkcvbqnzBRzVTiBo/view?usp=sharing That's as far as I got tonight. But I was poking at things. The profile report is 38GiB.

stellaraccident commented 3 weeks ago

--compile-to=vm IR: https://drive.google.com/file/d/1X27ro1iQEyS-Kvp5HEqUBoC8gnfyNNX4/view?usp=sharing

This phase took a mere 13m. A quick look at the profile shows almost all of that is in repeated application of two patterns in a canonicalizer which should probably be a pass.

Compilation to stream took 56m. There was a more varied assortment of bad things -- almost all of them canonicalization.

I'll need to write some tools to make sense of these profiles.

benvanik commented 3 weeks ago

Taking a peek.

Quick suggestion is to enable command buffer reuse/memoization: --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true. It doesn't handle dynamic cases well, but most of this is not dynamic and may be helped by at least the outlining. That will be the default once we are off HIP (off now because we can't do indirect dispatch there and support reuse in general). --iree-stream-resource-memory-model=discrete is likely to help too; that should be the default as well.
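
For concreteness, a hedged sketch of the reproduction compile with those suggestions applied (flag names as given in this comment; the device-flag loop and output name are just illustrative, and defaults may have changed in newer builds):

    # Sketch only: reproduction compile plus the suggested flags.
    DEVICE_FLAGS=()
    for i in $(seq 0 7); do DEVICE_FLAGS+=("--iree-hal-target-device=hip[$i]"); done
    iree-build/tools/iree-compile /data/llama-3.1/405b/llama3.1_405b_fp16_tp8.mlir \
      "${DEVICE_FLAGS[@]}" \
      --iree-hip-target=gfx942 \
      --iree-hal-indirect-command-buffers=true \
      --iree-hal-memoization=true \
      --iree-stream-resource-memory-model=discrete \
      -o 405b_f16_tp8.vmfb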

This is garbage-shaped, though, so there will be issues. Lots of them look low-hanging thankfully.

First sadness:

    vm.global.ref private mutable @__device_3_executable_7_prefill_bs4$async_dispatch_51 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_8_prefill_bs4$async_dispatch_67 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_9_prefill_bs4$async_dispatch_107 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_10_prefill_bs4$async_dispatch_115 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_11_prefill_bs4$async_dispatch_123 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_12_prefill_bs4$async_dispatch_131 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_13_prefill_bs4$async_dispatch_132 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_14_prefill_bs4$async_dispatch_147 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_15_prefill_bs4$async_dispatch_155 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_16_prefill_bs4$async_dispatch_156 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_17_prefill_bs4$async_dispatch_171 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_18_prefill_bs4$async_dispatch_179 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_19_prefill_bs4$async_dispatch_180 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_20_prefill_bs4$async_dispatch_195 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_21_prefill_bs4$async_dispatch_211 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_22_prefill_bs4$async_dispatch_227 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_23_prefill_bs4$async_dispatch_235 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_24_prefill_bs4$async_dispatch_75 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_25_prefill_bs4$async_dispatch_83 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_26_prefill_bs4$async_dispatch_91 : !vm.ref<!hal.executable>
    vm.global.ref private mutable @__device_3_executable_27_prefill_bs4$async_dispatch_92 : !vm.ref<!hal.executable>

...

We need to get linking fixed in the ROCM target: 200 executables x 8 devices vs 8 total. This needs attention soon - we have been kicking that can since no one wanted to do it for CUDA years ago. Maintaining 1600 executables is not cheap.

Second sadness:

    vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.3 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.4 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.5 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.6 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.7 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.0 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.1 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.2 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.3 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.4 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.5 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.6 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.7 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$1 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$2 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$3 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$4 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$5 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$6 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$7 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.0 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.1 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.2 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.3 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.4 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.5 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.6 : !vm.ref<!hal.buffer>
    vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.7 : !vm.ref<!hal.buffer>

We need to turn on parameter repacking in the compiler and/or the gather path (--iree-stream-resource-memory-model=discrete) for loading. 9100 discrete buffers is a lot to track and puts too much pressure on the system.

      vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-533724608, %c9], [(%zero_13, %zero_13, %ref_10337, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5301600256, %ref_10759, %c42483057216, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5301600256, %ref_10759, %c42550166080, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5301600256, %ref_10759, %c42617274944, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5301600256, %ref_10759, %c42684383808, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5301600256, %ref_10759, %c42751492672, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5301600256, %ref_10759, %c42818601536, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5301600256, %ref_10759, %c42885710400, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c3146304, %c10], [(%zero_13, %zero_13, %ref_10346, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5368709120, %ref_10759, %c43019928128, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5368709120, %ref_10759, %c43087036992, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5368709120, %ref_10759, %c43154145856, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5368709120, %ref_10759, %c43221254720, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5368709120, %ref_10759, %c43288363584, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5368709120, %ref_10759, %c43355472448, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5368709120, %ref_10759, %c43422581312, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c540017216, %c10], [(%zero_13, %zero_13, %ref_10355, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5435817984, %ref_10759, %c43556799040, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5435817984, %ref_10759, %c43623907904, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5435817984, %ref_10759, %c43691016768, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5435817984, %ref_10759, %c43758125632, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5435817984, %ref_10759, %c43825234496, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5435817984, %ref_10759, %c43892343360, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
      vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5435817984, %ref_10759, %c43959452224, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()

I don't know what's happening during initialization time, but that looks bad (repeated dispatches and copies of the same thing).

      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-1073741824, %c1], [(%zero_13, %zero_13, %ref_9489, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-1006632960, %c1], [(%zero_13, %zero_13, %ref_9498, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-939524096, %c1], [(%zero_13, %zero_13, %ref_9507, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-872415232, %c1], [(%zero_13, %zero_13, %ref_9516, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-805306368, %c1], [(%zero_13, %zero_13, %ref_9525, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-738197504, %c1], [(%zero_13, %zero_13, %ref_9534, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-671088640, %c1], [(%zero_13, %zero_13, %ref_9543, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-603979776, %c1], [(%zero_13, %zero_13, %ref_9552, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-536870912, %c1], [(%zero_13, %zero_13, %ref_9561, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-469762048, %c1], [(%zero_13, %zero_13, %ref_9570, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-402653184, %c1], [(%zero_13, %zero_13, %ref_9579, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-335544320, %c1], [(%zero_13, %zero_13, %ref_9588, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
      vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-268435456, %c1], [(%zero_13, %zero_13, %ref_9597, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)

Also initialization time - that's probably bad unless it's initializing variables. Something those are touching is 8455716864 bytes (~7.9 GiB), though, which is a lot.

benvanik commented 3 weeks ago

Also:

      vm.call @hal.device.queue.dealloca(%__device_1, %c-1, %null, %ref_9340, %ref_9336) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.buffer>) -> ()
      %8884 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9340]) : (i32, !vm.ref<!hal.fence> ...) -> i32
      %ref_9341 = vm.call.variadic @hal.fence.join([%ref_9277, %ref_9276, %ref_9340, %ref_9339, %ref_9295, %ref_9294, %ref_9304, %ref_9303, %ref_9313, %ref_9312, %ref_9322, %ref_9321, %ref_9331, %ref_9330, %ref_9286, %ref_9285, %ref_100]) {nosideeffects} : (!vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
      %ref_9342 = vm.call @hal.fence.create(%__device_2, %zero) : (!vm.ref<!hal.device>, i32) -> !vm.ref<!hal.fence>
      %8885 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9341]) : (i32, !vm.ref<!hal.fence> ...) -> i32
      %ref_9343 = vm.call @hal.device.queue.alloca(%__device_2, %c-1, %null, %ref_9342, %zero, %c48, %c3075, %c149504_31) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i32, i32, i32, i64) -> !vm.ref<!hal.buffer>
      %8886 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9342]) : (i32, !vm.ref<!hal.fence> ...) -> i32
      %ref_9344 = vm.call @hal.fence.create(%__device_2, %zero) : (!vm.ref<!hal.device>, i32) -> !vm.ref<!hal.fence>
      %8887 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9341]) : (i32, !vm.ref<!hal.fence> ...) -> i32
      %ref_9345 = vm.call @hal.device.queue.alloca(%__device_2, %c-1, %null, %ref_9344, %zero, %c48, %c3075, %c917504_49) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i32, i32, i32, i64) -> !vm.ref<!hal.buffer>
      %8888 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9344]) : (i32, !vm.ref<!hal.fence> ...) -> i32

Those fence awaits are bad and are probably being inserted because --iree-hip-legacy-sync=true is still the default. It shouldn't be. We need to remove that flag entirely.

benvanik commented 3 weeks ago

I'd assumed we were hitting the i32 register max, but we're hitting the ref register max (which is a much lower 16383 - because no reasonable program should have 16k live ref counted values in the same function). Fixing executables and parameters will drop that dramatically (from ~1600+9100=~10k to a few dozen). Reusable command buffers/memoization should then make functions that only use a fraction of the total resources and take it down to a handful per function.

stellaraccident commented 3 weeks ago

Ok, so need to find a way to thread the needle while we burn down the debt.

Looks like we can immediately try:

Things that require some coding:

benvanik commented 3 weeks ago

Both flags should work today with HIP - they're emulated, but way more efficient than not memoizing them. Each command buffer (stream execution region) gets outlined to its own function. You will have to remove that legacy sync flag (we should delete that - if things don't work with it off then those should be P0).

stellaraccident commented 3 weeks ago

Legacy sync flag (still defaults to true for rocm :(): https://github.com/iree-org/iree/blob/main/compiler/plugins/target/ROCM/ROCMTarget.cpp#L103

benvanik commented 3 weeks ago

Linking implemented in #18936 - it should help reduce the executable ref count from ~1600 to 8 (one for each device).

KyleHerndon commented 3 weeks ago

Without yet using your changes in #18936, I compiled by adding the following parameters: --iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False and got the following error when running the output:

    iree/runtime/src/iree/vm/bytecode/archive.c:122: INVALID_ARGUMENT; FlatBuffer length prefix out of bounds (prefix is 1969516397 but only 70499378 available); loading bytecode module at '405b_f16_tp8.vmfb'; loading modules and dependencies; creating run context

benvanik commented 3 weeks ago

That may indicate a corrupt VMFB.

KyleHerndon commented 3 weeks ago

Ah, I found my mistake. I thought --compile-to=vm would produce the vmfb, but it seems it is one step before it. Tried compiling the rest of the way, but the compiler just seems to be crashing with this error:

    LLVM ERROR: can't create Attribute 'mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.

KyleHerndon commented 3 weeks ago

The same error occurs when I run with all of the flags: --iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true. Would uploading a vmfb or any of the compilation steps be helpful?

benvanik commented 3 weeks ago

vmfbs aren't useful, but the vm IR is (--compile-to=vm) as well as the stream IR (--compile-to=stream).

KyleHerndon commented 3 weeks ago

Intermediate MLIRs for this compile command:

    iree-build/tools/iree-compile ./405b_f16_tp8.mlir --iree-hal-target-device=hip[0] --iree-hal-target-device=hip[1] --iree-hal-target-device=hip[2] --iree-hal-target-device=hip[3] --iree-hal-target-device=hip[4] --iree-hal-target-device=hip[5] --iree-hal-target-device=hip[6] --iree-hal-target-device=hip[7] --iree-hal-executable-debug-level=3 --iree-hip-target=gfx942 -o 405b_f16_tp8.vmfb --iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true

MLIRs: stream, vm

benvanik commented 3 weeks ago

Cool - I'll compile down and see if I can spot anything and use #18936 (which should help a lot).

Still lots of suspicious things in here. Are we not fusing things with matmuls? Seeing batch_matmul dispatches followed by elementwise addition dispatches.

benvanik commented 3 weeks ago

A large number of dispatches in here are slow_memcpy. There are literally 15128 slow_memcpy dispatches. 🙅🙅🙅🙅🙅🙅🙅🙅

        %18 = flow.dispatch.tensor.load %16, offsets = [0, 0, 0], sizes = [4, %15, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?x128xf16>>{%15} -> tensor<4x?x128xf16>
        flow.dispatch.tensor.store %18, %17, offsets = [0, 0, 0, 0], sizes = [4, %15, 1, 128], strides = [1, 1, 1, 1] : tensor<4x?x128xf16> -> !flow.dispatch.tensor<readwrite:tensor<4x?x1x128xf16>>{%15}

Looks like most are just adding a unit dimension.

benvanik commented 3 weeks ago

#18945 has a new flag that should be used instead of --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true (as those are on by default now): --iree-hal-force-indirect-command-buffers=true. This should yield a program with a bunch of functions.
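
A hedged sketch of what the updated invocation might look like (assuming a build that includes #18945, dropping the two now-default flags and carrying over the other flags used earlier in the thread; the device-flag loop and output name are illustrative):

    # Sketch only: force memoized/indirect command buffers via the new flag.
    DEVICE_FLAGS=()
    for i in $(seq 0 7); do DEVICE_FLAGS+=("--iree-hal-target-device=hip[$i]"); done
    iree-build/tools/iree-compile /data/llama-3.1/405b/llama3.1_405b_fp16_tp8.mlir \
      "${DEVICE_FLAGS[@]}" \
      --iree-hip-target=gfx942 \
      --iree-hal-force-indirect-command-buffers=true \
      --iree-stream-resource-memory-model=discrete \
      --iree-hip-legacy-sync=false \
      -o 405b_f16_tp8.vmfb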

KyleHerndon commented 3 weeks ago

Retried from commit 12cb042b3ef4d4c16aab9fe232d1ff6c5a9e9888 and I'm still running into the MMAIntrinsicAttr error. Updated MLIR: stream, vm

benvanik commented 3 weeks ago

can you also post flow? thx!

KyleHerndon commented 3 weeks ago

flow

benvanik commented 3 weeks ago

nice

I was fearing something like this:

    %60411 = flow.tensor.splat %c79_i64 : tensor<4x2xi64>
    %60412 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_0>
    %60413 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_1>
    %60414 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_2>
    %60415 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_3>
    %60416 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_4>
    %60417 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_5>
    %60418 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_6>
    %60419 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_7>

Need to have a flow.tensor.transfer canonicalizer pattern to always fold away splats, but we may want something earlier than that.

This is leading to some of the badness we have with unfused slow memcpys:

    %52801 = flow.tensor.splat %c69_i64 : tensor<4x2xi64>
    %52802 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_0>
    %52803 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_1>
    %52804 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_2>
    %52805 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_3>
    %52806 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_4>
    %52807 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_5>
    %52808 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_6>
    %52809 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_7>
    %52810 = flow.tensor.reshape %52802 : tensor<4x2xi64> -> tensor<8xi64>
    %52811 = flow.dispatch @decode_bs4$async_dispatch_108::@decode_bs4$async_dispatch_108_slow_memcpy(%52810, %259) : (tensor<8xi64>, tensor<8x4xi64>) -> %259
    %52812 = flow.dispatch @decode_bs4$async_dispatch_109::@decode_bs4$async_dispatch_109_slow_memcpy(%256, %52811) : (tensor<8xi64>, tensor<8x4xi64>) -> %52811
    %52813 = flow.dispatch @decode_bs4$async_dispatch_110::@decode_bs4$async_dispatch_110_slow_memcpy(%262, %52812) : (tensor<8xi64>, tensor<8x4xi64>) -> %52812

That %c69 should just be an i64 operand to those dispatches, and I bet those dispatches may never have been formed.

A flow.tensor.transfer canonicalizer on flow.tensor.splat would help later on, but we'd need to make sure it happened on the input prior to dispatch region formation too (maybe as part of global opt).

benvanik commented 3 weeks ago

It'd also be useful to get the model exported from torch with @stellaraccident's assertions on shape dimensions.

benvanik commented 3 weeks ago

To speed up compilation we should probably get around to folding the flow.tensor.reshapes into the dispatch regions - when we have 200k+ ops and ~70k of them are flow.tensor.reshape ops that mean nothing, we could speed up some stages of the compiler here. May hurt deduplication, though. Hm.

benvanik commented 3 weeks ago

Unelided timepoints create a lot of IR and a lot of runtime references that are long-lived. Found the issue and filed #18960. I ran with a 64k PVS capacity - it was slow but did mostly resolve all of the covered timepoints. There's still many remaining joins but that's because it's sharded 8 ways and we'd expect a lot of joins of 8 timepoints.

Before/after:

Before:

    %12691 = stream.timepoint.join max(%12665, %12664, %12641, %12640, %12645, %12644, %12649, %12648, %12653, %12652, %12657, %12656, %12661, %12660, %12637, %12636, %12669, %12668, %12673, %12672, %12677, %12676, %12681, %12680, %12685, %12684, %12689, %12688, %1, %__auto.token_embd.weight__timepoint, %result_timepoint_12897, %result_timepoint_12899) => !stream.timepoint

After:

    %10812 = stream.timepoint.join max(%result_timepoint_12897, %result_timepoint_12899) => !stream.timepoint

Before:

    %6153 = stream.resource.dealloca on(#hal.device.affinity<@__device_1>) await(%6152) => %result_5604 : !stream.resource<transient>{%444} => !stream.timepoint
    %6154 = stream.timepoint.join max(%6125, %6124, %6153, %6152, %6133, %6132, %6137, %6136, %6141, %6140, %6145, %6144, %6149, %6148, %6129, %6128, %192) => !stream.timepoint
    %result_5606, %result_timepoint_5607 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%6154) => !stream.resource<transient>{%514} => !stream.timepoint
    %result_5608, %result_timepoint_5609 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%6154) => !stream.resource<transient>{%525} => !stream.timepoint
    %6155 = stream.timepoint.join max(%6125, %6124, %6153, %6152, %6133, %6132, %6137, %6136, %6141, %6140, %6145, %6144, %6149, %6148, %6129, %6128, %192, %result_timepoint_5607, %result_timepoint_5609) => !stream.timepoint
    %6156 = stream.cmd.execute on(#hal.device.affinity<@__device_2>) await(%6155) ...

After:

    %4995 = stream.resource.dealloca on(#hal.device.affinity<@__device_1>) await(%4994) => %result_5604 : !stream.resource<transient>{%444} => !stream.timepoint
    %result_5606, %result_timepoint_5607 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%4995) => !stream.resource<transient>{%514} => !stream.timepoint
    %result_5608, %result_timepoint_5609 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%4995) => !stream.resource<transient>{%525} => !stream.timepoint
    %4996 = stream.timepoint.join max(%result_timepoint_5607, %result_timepoint_5609) => !stream.timepoint
    %4997 = stream.cmd.execute on(#hal.device.affinity<@__device_2>) await(%4996) ...

So definitely a lot of room for improvement when that pass is rewritten.

KyleHerndon commented 3 weeks ago

If I understand correctly, this version might include shape assertions because it includes this patch during export. Original, flow, stream, vm

benvanik commented 3 weeks ago

Nice, that includes them! [image]

(note that the workload ordinal has a bounded range there now - other dispatches now also have divisibility, e.g. %14<umin = 0, umax = 9007199254740991> : index -> %10<umin = 16, umax = 131056, udiv = 16>)