KyleHerndon opened 4 weeks ago
@benvanik I think this was one of those "if you ever need this, we want to take a close look" kinds of things. With that said, we took a program that was already large and multiplied it by 8...
That means it's hitting 65k live values. There's no raising that limit - the fix is making the compiler better. Best bet is to split each thing into its own function instead of inlining it all into one.
(a --compile-to=vm dump would be useful in seeing how the hell there are 65k live values - it may just be a bug/pathological behavior in the crappy register allocator I begrudgingly wrote 5 years ago :)
Ah this is live values. That is a lot. My bet is on something that has been pathological for a while and then we just multiplied it by 8.
Yeah, maybe it's got like 40k unique constants that are used multiple times so their lifetimes overlap. I've got a TODO for making constants better but still would not be great (no compiler/JIT/etc ever wants to see a function with tens of thousands of locals :). We need the compiler to spill (e.g. to a temporary vm.list acting as the stack) or rematerialize (clone constants near uses). The former sucks for performance but if the latter was the issue it'd be ok.
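The rematerialization idea above can be shown with a toy liveness model (this is illustrative Python, not IREE code): constants defined once at the top of a function stay live until their last use, so all their lifetimes overlap, while cloning each constant right before its use splits those long ranges into many short ones.

```python
# Toy liveness model: count the peak number of simultaneously live values
# given a list of (def_position, last_use_position) live ranges.

def max_live(defs_and_uses):
    """Return the max number of simultaneously live values."""
    events = []
    for start, end in defs_and_uses:
        events.append((start, 1))       # value becomes live at its def
        events.append((end + 1, -1))    # value dies after its last use
    live = peak = 0
    for _, delta in sorted(events):
        live += delta
        peak = max(peak, live)
    return peak

# 1000 constants all defined at position 0, each last used late in the function:
shared = max_live([(0, 5000 + i) for i in range(1000)])

# Same uses, but each constant is cloned immediately before its (single) use:
remat = max_live([(5000 + i, 5000 + i) for i in range(1000)])

print(shared, remat)  # 1000 vs 1 simultaneously live values
```

The same use sites exist in both versions; only the placement of the constant definitions changes the register pressure.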
(having one non-inlinable function per shard is going to be the best, though - you'll get way faster compiles too for things nested on function scope)
That would be best but it's kind of the opposite of how any frontend I know of will want to split things.
It can be something we do near the frontend/global opt (affinity analysis -> slice -> outline).
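The "affinity analysis -> slice -> outline" step could look something like the sketch below (hypothetical Python, with made-up names like `outline_by_affinity` - not an IREE API): group ops by the device they are pinned to, then emit one function per group so no single function carries the whole 8x program.

```python
# Hypothetical outlining sketch: partition a flat op list by device affinity
# into per-shard functions.
from collections import defaultdict

def outline_by_affinity(ops):
    """ops: list of (op_name, device_affinity). Returns {func_name: [op_names]}."""
    shards = defaultdict(list)
    for name, device in ops:
        shards[f"__shard_{device}"].append(name)
    return dict(shards)

ops = [("matmul_0", 0), ("matmul_1", 1), ("add_0", 0), ("add_1", 1)]
print(outline_by_affinity(ops))
# {'__shard_0': ['matmul_0', 'add_0'], '__shard_1': ['matmul_1', 'add_1']}
```

A real pass would also have to thread cross-shard values through function arguments and results; the point here is just that slicing happens on the device axis, not on whatever structure the frontend emitted.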
MLIR/linalg/etc is not built for big things -- and in some cases not even medium ones -- and if we want reasonable compile times we'll need to do something other than chew on bajillion op functions - if the frontend can simplify things that's best but otherwise we need to :(
Yeah, that's what I meant... Need to partition on some axis
But we need to look at this specific case. In the optimized single device cases I was producing, things weren't crazy. My vote is on pathological machinery somewhere
Yeah --compile-to=vm will be the best place to start - I bet there's some low hanging fruit.
Here is IR (as bytecode) of /home/stella/tmp/llamabig/llama3.1_405b_f16_tp8.mlir.stream : https://drive.google.com/file/d/1Y_G7helvJ15R8DVcVkcvbqnzBRzVTiBo/view?usp=sharing That's as far as I got tonight. But I was poking at things. The profile report is 38GiB.
--compile-to=vm IR: https://drive.google.com/file/d/1X27ro1iQEyS-Kvp5HEqUBoC8gnfyNNX4/view?usp=sharing
This phase took a mere 13m. A quick look at the profile shows almost all of that is in repeated application of two patterns in a canonicalizer which should probably be a pass.
Compilation to stream took 56m. There was a more varied assortment of bad things -- almost all of them canonicalization.
I'll need to write some tools to make sense of these profiles.
Taking a peek.
Quick suggestion is to enable command buffer reuse/memoization:
--iree-hal-indirect-command-buffers=true --iree-hal-memoization=true
It doesn't handle dynamic cases well but most of this is not dynamic and may be able to be helped by at least the outlining. That will be the default once we are off HIP (off now because we can't do indirect dispatch there and support reuse in general).
--iree-stream-resource-memory-model=discrete
is likely to help too. That should be the default.
This is garbage-shaped, though, so there will be issues. Lots of them look low-hanging thankfully.
First sadness:
vm.global.ref private mutable @__device_3_executable_7_prefill_bs4$async_dispatch_51 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_8_prefill_bs4$async_dispatch_67 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_9_prefill_bs4$async_dispatch_107 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_10_prefill_bs4$async_dispatch_115 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_11_prefill_bs4$async_dispatch_123 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_12_prefill_bs4$async_dispatch_131 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_13_prefill_bs4$async_dispatch_132 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_14_prefill_bs4$async_dispatch_147 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_15_prefill_bs4$async_dispatch_155 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_16_prefill_bs4$async_dispatch_156 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_17_prefill_bs4$async_dispatch_171 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_18_prefill_bs4$async_dispatch_179 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_19_prefill_bs4$async_dispatch_180 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_20_prefill_bs4$async_dispatch_195 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_21_prefill_bs4$async_dispatch_211 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_22_prefill_bs4$async_dispatch_227 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_23_prefill_bs4$async_dispatch_235 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_24_prefill_bs4$async_dispatch_75 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_25_prefill_bs4$async_dispatch_83 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_26_prefill_bs4$async_dispatch_91 : !vm.ref<!hal.executable>
vm.global.ref private mutable @__device_3_executable_27_prefill_bs4$async_dispatch_92 : !vm.ref<!hal.executable>
...
**We need to get linking fixed in the ROCM target. 200 executables × 8 devices vs 8 total.** This needs attention soon - I've been kicking that can since no one wanted to do it for CUDA years ago. Maintaining 1600 executables is not cheap.
Second sadness:
vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.3 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.4 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.5 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.6 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_v.weight.shard.7 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.0 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.1 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.2 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.3 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.4 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.5 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.6 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.attn_output.weight.shard.7 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$1 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$2 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$3 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$4 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$5 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$6 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_norm.weight$7 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.0 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.1 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.2 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.3 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.4 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.5 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.6 : !vm.ref<!hal.buffer>
vm.global.ref private mutable @__auto.blk.117.ffn_gate.weight.shard.7 : !vm.ref<!hal.buffer>
We need to get parameter repacking on in the compiler and/or the gather path (--iree-stream-resource-memory-model=discrete) on for loading. 9100 discrete buffers is a lot to track and puts too much pressure on the system.
vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-533724608, %c9], [(%zero_13, %zero_13, %ref_10337, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5301600256, %ref_10759, %c42483057216, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5301600256, %ref_10759, %c42550166080, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5301600256, %ref_10759, %c42617274944, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5301600256, %ref_10759, %c42684383808, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5301600256, %ref_10759, %c42751492672, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5301600256, %ref_10759, %c42818601536, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5301600256, %ref_10759, %c42885710400, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c3146304, %c10], [(%zero_13, %zero_13, %ref_10346, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5368709120, %ref_10759, %c43019928128, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5368709120, %ref_10759, %c43087036992, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5368709120, %ref_10759, %c43154145856, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5368709120, %ref_10759, %c43221254720, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5368709120, %ref_10759, %c43288363584, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5368709120, %ref_10759, %c43355472448, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5368709120, %ref_10759, %c43422581312, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call.variadic @hal.command_buffer.dispatch(%ref_10763, %__device_0_executable_6__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c540017216, %c10], [(%zero_13, %zero_13, %ref_10355, %zero, %c67108864), (%zero_13, %zero_13, %ref_10759, %zero, %c67648881280)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_2750, %c5435817984, %ref_10759, %c43556799040, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_3894, %c5435817984, %ref_10759, %c43623907904, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_5038, %c5435817984, %ref_10759, %c43691016768, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_6182, %c5435817984, %ref_10759, %c43758125632, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_7326, %c5435817984, %ref_10759, %c43825234496, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_8470, %c5435817984, %ref_10759, %c43892343360, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
vm.call @hal.command_buffer.copy_buffer(%ref_10763, %zero_13, %zero_13, %ref_9614, %c5435817984, %ref_10759, %c43959452224, %c67108864) : (!vm.ref<!hal.command_buffer>, i32, i32, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
I don't know what's happening during initialization time, but that looks bad (repeated dispatches and copies of the same thing).
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-1073741824, %c1], [(%zero_13, %zero_13, %ref_9489, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-1006632960, %c1], [(%zero_13, %zero_13, %ref_9498, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-939524096, %c1], [(%zero_13, %zero_13, %ref_9507, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-872415232, %c1], [(%zero_13, %zero_13, %ref_9516, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-805306368, %c1], [(%zero_13, %zero_13, %ref_9525, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-738197504, %c1], [(%zero_13, %zero_13, %ref_9534, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-671088640, %c1], [(%zero_13, %zero_13, %ref_9543, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-603979776, %c1], [(%zero_13, %zero_13, %ref_9552, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-536870912, %c1], [(%zero_13, %zero_13, %ref_9561, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-469762048, %c1], [(%zero_13, %zero_13, %ref_9570, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-402653184, %c1], [(%zero_13, %zero_13, %ref_9579, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-335544320, %c1], [(%zero_13, %zero_13, %ref_9588, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
vm.call.variadic @hal.command_buffer.dispatch(%ref_9616, %__device_7_executable_0__initializer_11_dispatch_0, %zero_13, %c64, %c512, %c1, %zero, [%c-268435456, %c1], [(%zero_13, %zero_13, %ref_9597, %zero, %c67108864), (%zero_13, %zero_13, %ref_9614, %zero, %c8455716864)]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32, i64, i32 ..., tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64> ...)
Also initialization time - that's probably bad unless it's initializing variables. Something those are touching is 8455716864 bytes, though, which is a lot.
Also:
vm.call @hal.device.queue.dealloca(%__device_1, %c-1, %null, %ref_9340, %ref_9336) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, !vm.ref<!hal.buffer>) -> ()
%8884 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9340]) : (i32, !vm.ref<!hal.fence> ...) -> i32
%ref_9341 = vm.call.variadic @hal.fence.join([%ref_9277, %ref_9276, %ref_9340, %ref_9339, %ref_9295, %ref_9294, %ref_9304, %ref_9303, %ref_9313, %ref_9312, %ref_9322, %ref_9321, %ref_9331, %ref_9330, %ref_9286, %ref_9285, %ref_100]) {nosideeffects} : (!vm.ref<!hal.fence> ...) -> !vm.ref<!hal.fence>
%ref_9342 = vm.call @hal.fence.create(%__device_2, %zero) : (!vm.ref<!hal.device>, i32) -> !vm.ref<!hal.fence>
%8885 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9341]) : (i32, !vm.ref<!hal.fence> ...) -> i32
%ref_9343 = vm.call @hal.device.queue.alloca(%__device_2, %c-1, %null, %ref_9342, %zero, %c48, %c3075, %c149504_31) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i32, i32, i32, i64) -> !vm.ref<!hal.buffer>
%8886 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9342]) : (i32, !vm.ref<!hal.fence> ...) -> i32
%ref_9344 = vm.call @hal.fence.create(%__device_2, %zero) : (!vm.ref<!hal.device>, i32) -> !vm.ref<!hal.fence>
%8887 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9341]) : (i32, !vm.ref<!hal.fence> ...) -> i32
%ref_9345 = vm.call @hal.device.queue.alloca(%__device_2, %c-1, %null, %ref_9344, %zero, %c48, %c3075, %c917504_49) : (!vm.ref<!hal.device>, i64, !vm.ref<!hal.fence>, !vm.ref<!hal.fence>, i32, i32, i32, i64) -> !vm.ref<!hal.buffer>
%8888 = vm.call.variadic @hal.fence.await(%c-1_59, [%ref_9344]) : (i32, !vm.ref<!hal.fence> ...) -> i32
Those fence awaits are bad and probably being set by --iree-hip-legacy-sync=true
which is still the default. It shouldn't be. We need to remove that flag entirely.
I'd assumed we were hitting the i32 register max, but we're hitting the ref register max (which is a much lower 16383 - because no reasonable program should have 16k live ref counted values in the same function). Fixing executables and parameters will drop that dramatically (from ~1600+9100=~10k to a few dozen). Reusable command buffers/memoization should then make functions that only use a fraction of the total resources and take it down to a handful per function.
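A quick back-of-envelope on the counts mentioned above (the "after" numbers here are illustrative, assuming linking collapses executables to one per device and parameters get packed):

```python
# Ref register budget arithmetic for the numbers discussed in this thread.
REF_REGISTER_MAX = 16383       # VM ref register limit per function (per discussion)

executables = 200 * 8          # ~200 executables x 8 devices before linking
parameter_buffers = 9100       # discrete parameter buffers
before = executables + parameter_buffers
after = 8 + 8                  # illustrative: one linked executable per device + packed buffers

print(before, after)           # 10700 vs 16: "before" leaves little headroom for
                               # everything else live in a big function
assert before < REF_REGISTER_MAX < 65536
```

So globals alone don't overflow the 16383 limit, but they eat most of it; the remaining live refs in a large function push it over.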
Ok, so need to find a way to thread the needle while we burn down the debt.
Looks like we can immediately try:
--iree-stream-resource-memory-model=discrete (seems like that should give the biggest gain)
--iree-hal-indirect-command-buffers=true --iree-hal-memoization=true might help, but confirming that it is a no-go for HIP? :(
Things that require some coding:
Both flags should work today with HIP - they're emulated, but way more efficient than not memoizing them. Each command buffer (stream execution region) gets outlined to its own function. You will have to remove that legacy sync flag (we should delete that - if things don't work with it off then those should be P0).
Legacy sync flag (still defaults to true for rocm :(): https://github.com/iree-org/iree/blob/main/compiler/plugins/target/ROCM/ROCMTarget.cpp#L103
Linking implemented in #18936 - it should help reduce the executable ref count from ~1600 to 8 (one for each device).
Without yet using your changes in #18936, I compiled by adding the following parameters
--iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False
And got the following error when running the output:
iree/runtime/src/iree/vm/bytecode/archive.c:122: INVALID_ARGUMENT; FlatBuffer length prefix out of bounds (prefix is 1969516397 but only 70499378 available); loading bytecode module at '405b_f16_tp8.vmfb'; loading modules and dependencies; creating run context
That may indicate a corrupt VMFB.
Ah, I found my mistake. I thought --compile-to=vm would produce the vmfb, but it seems it is one step before it. Tried compiling the rest of the way, but the compiler just seems to be crashing with this error:
LLVM ERROR: can't create Attribute 'mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
The same error occurs when I run with all of the flags:
--iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true
Would uploading a vmfb or any of the compilation steps be helpful?
vmfbs aren't useful, but the vm IR is (--compile-to=vm) as well as the stream IR (--compile-to=stream).
Intermediate MLIRs for compile command:
iree-build/tools/iree-compile ./405b_f16_tp8.mlir --iree-hal-target-device=hip[0] --iree-hal-target-device=hip[1] --iree-hal-target-device=hip[2] --iree-hal-target-device=hip[3] --iree-hal-target-device=hip[4] --iree-hal-target-device=hip[5] --iree-hal-target-device=hip[6] --iree-hal-target-device=hip[7] --iree-hal-executable-debug-level=3 --iree-hip-target=gfx942 -o 405b_f16_tp8.vmfb --iree-stream-resource-memory-model=discrete --iree-hip-legacy-sync=False --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true
MLIRs:
Stream
vm
Cool - I'll compile down and see if I can spot anything and use #18936 (which should help a lot).
Still lots of suspicious things in here. Are we not fusing things with matmuls? Seeing batch_matmul dispatches followed by elementwise addition dispatches.
A large number of dispatches in here are slow_memcpy. There are literally 15128 slow_memcpy dispatches. 🙅🙅🙅🙅🙅🙅🙅🙅
%18 = flow.dispatch.tensor.load %16, offsets = [0, 0, 0], sizes = [4, %15, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?x128xf16>>{%15} -> tensor<4x?x128xf16>
flow.dispatch.tensor.store %18, %17, offsets = [0, 0, 0, 0], sizes = [4, %15, 1, 128], strides = [1, 1, 1, 1] : tensor<4x?x128xf16> -> !flow.dispatch.tensor<readwrite:tensor<4x?x1x128xf16>>{%15}
looks like most are just adding a unit dimension.
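Adding a unit dimension is a metadata-only change - no data needs to move. NumPy illustrates the same point: expanding a dim yields a view over the same storage, not a copy.

```python
# tensor<4x?x128xf16> -> tensor<4x?x1x128xf16> is a pure reshape: same
# elements, same layout, one extra unit dim. NumPy does this without copying.
import numpy as np

x = np.zeros((4, 7, 128), dtype=np.float16)   # stand-in for tensor<4x?x128xf16>
y = x[:, :, np.newaxis, :]                    # tensor<4x?x1x128xf16>, no copy

assert y.shape == (4, 7, 1, 128)
assert np.shares_memory(x, y)                 # same buffer - a reshape, not a memcpy
```

Which is why seeing these as dedicated slow_memcpy dispatches is so wasteful: the dispatch does real data movement for something that should fold away entirely.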
You can drop --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true (as those are on by default now) and instead pass --iree-hal-force-indirect-command-buffers=true. This should yield a program with a bunch of functions. Can you also post flow? thx!
nice
I was fearing something like this:
%60411 = flow.tensor.splat %c79_i64 : tensor<4x2xi64>
%60412 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_0>
%60413 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_1>
%60414 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_2>
%60415 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_3>
%60416 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_4>
%60417 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_5>
%60418 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_6>
%60419 = flow.tensor.transfer %60411 : tensor<4x2xi64> to #hal.device.affinity<@__device_7>
Need to have a flow.tensor.transfer canonicalizer pattern to always fold away splats, but we may want something earlier than that.
this is leading to some of the badness we have with unfused slow memcpys:
%52801 = flow.tensor.splat %c69_i64 : tensor<4x2xi64>
%52802 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_0>
%52803 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_1>
%52804 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_2>
%52805 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_3>
%52806 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_4>
%52807 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_5>
%52808 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_6>
%52809 = flow.tensor.transfer %52801 : tensor<4x2xi64> to #hal.device.affinity<@__device_7>
%52810 = flow.tensor.reshape %52802 : tensor<4x2xi64> -> tensor<8xi64>
%52811 = flow.dispatch @decode_bs4$async_dispatch_108::@decode_bs4$async_dispatch_108_slow_memcpy(%52810, %259) : (tensor<8xi64>, tensor<8x4xi64>) -> %259
%52812 = flow.dispatch @decode_bs4$async_dispatch_109::@decode_bs4$async_dispatch_109_slow_memcpy(%256, %52811) : (tensor<8xi64>, tensor<8x4xi64>) -> %52811
%52813 = flow.dispatch @decode_bs4$async_dispatch_110::@decode_bs4$async_dispatch_110_slow_memcpy(%262, %52812) : (tensor<8xi64>, tensor<8x4xi64>) -> %52812
That %c69 should just be an i64 operand to those dispatches, and I bet those dispatches may never have been formed.
A flow.tensor.transfer canonicalizer on flow.tensor.splat would help later on, but we'd need to make sure it happened on the input prior to dispatch region formation too (maybe as part of global opt).
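The pattern is simple enough to sketch as a toy rewrite (illustrative Python, not actual MLIR pattern code): a transfer of a splat is just the splat materialized on the target device, so the cross-device copy disappears.

```python
# Toy rewrite for the canonicalization described above:
#   transfer(splat(v), dev)  ->  splat(v) @ dev

def fold_transfer_of_splat(op, defs):
    """op: ('transfer', src_name, device); defs: {name: op that defined it}."""
    kind, src, device = op
    src_op = defs.get(src)
    if kind == "transfer" and src_op and src_op[0] == "splat":
        _, value, _ = src_op
        return ("splat", value, device)   # rematerialize the splat on the target device
    return op

defs = {"%60411": ("splat", 79, None)}
folded = fold_transfer_of_splat(("transfer", "%60411", "@__device_3"), defs)
print(folded)  # ('splat', 79, '@__device_3')
```

With all 8 transfers folded this way, each device just materializes its own constant tensor locally and the splat's value can then propagate into the consuming dispatches.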
it'd also be useful to get the model exported from torch with @stellaraccident's assertions on shape dimensions
To speed up compilation we should probably get around to folding the flow.tensor.reshapes into the dispatch regions - when we have 200k+ ops and ~70k of them are flow.tensor.reshape ops that mean nothing, we could speed up some stages of the compiler here. May hurt deduplication, though. Hm.
Unelided timepoints create a lot of IR and a lot of runtime references that are long-lived. Found the issue and filed #18960. I ran with a 64k PVS capacity - it was slow but did mostly resolve all of the covered timepoints. There's still many remaining joins but that's because it's sharded 8 ways and we'd expect a lot of joins of 8 timepoints.
Before/after:
%12691 = stream.timepoint.join max(%12665, %12664, %12641, %12640, %12645, %12644, %12649, %12648, %12653, %12652, %12657, %12656, %12661, %12660, %12637, %12636, %12669, %12668, %12673, %12672, %12677, %12676, %12681, %12680, %12685, %12684, %12689, %12688, %1, %__auto.token_embd.weight__timepoint, %result_timepoint_12897, %result_timepoint_12899) => !stream.timepoint
>
%10812 = stream.timepoint.join max(%result_timepoint_12897, %result_timepoint_12899) => !stream.timepoint
%6153 = stream.resource.dealloca on(#hal.device.affinity<@__device_1>) await(%6152) => %result_5604 : !stream.resource<transient>{%444} => !stream.timepoint
%6154 = stream.timepoint.join max(%6125, %6124, %6153, %6152, %6133, %6132, %6137, %6136, %6141, %6140, %6145, %6144, %6149, %6148, %6129, %6128, %192) => !stream.timepoint
%result_5606, %result_timepoint_5607 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%6154) => !stream.resource<transient>{%514} => !stream.timepoint
%result_5608, %result_timepoint_5609 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%6154) => !stream.resource<transient>{%525} => !stream.timepoint
%6155 = stream.timepoint.join max(%6125, %6124, %6153, %6152, %6133, %6132, %6137, %6136, %6141, %6140, %6145, %6144, %6149, %6148, %6129, %6128, %192, %result_timepoint_5607, %result_timepoint_5609) => !stream.timepoint
%6156 = stream.cmd.execute on(#hal.device.affinity<@__device_2>) await(%6155) ...
>
%4995 = stream.resource.dealloca on(#hal.device.affinity<@__device_1>) await(%4994) => %result_5604 : !stream.resource<transient>{%444} => !stream.timepoint
%result_5606, %result_timepoint_5607 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%4995) => !stream.resource<transient>{%514} => !stream.timepoint
%result_5608, %result_timepoint_5609 = stream.resource.alloca uninitialized on(#hal.device.affinity<@__device_2>) await(%4995) => !stream.resource<transient>{%525} => !stream.timepoint
%4996 = stream.timepoint.join max(%result_timepoint_5607, %result_timepoint_5609) => !stream.timepoint
%4997 = stream.cmd.execute on(#hal.device.affinity<@__device_2>) await(%4996) ...
So definitely a lot of room for improvement when that pass is rewritten.
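The elision shown in the before/after can be sketched as a transitive-cover prune (assuming "covered" means reachable through another join operand's dependency chain; this is a toy model, not the actual pass):

```python
# Sketch: an operand of a timepoint join is redundant if another operand
# already awaits it transitively.

def prune_join(operands, deps):
    """deps: {tp: set of tps it awaits}. Keep only operands not covered by others."""
    def reachable(tp, seen=None):
        seen = set() if seen is None else seen
        for d in deps.get(tp, ()):
            if d not in seen:
                seen.add(d)
                reachable(d, seen)
        return seen

    kept = []
    for tp in operands:
        others = [o for o in operands if o != tp]
        if not any(tp in reachable(o) for o in others):
            kept.append(tp)
    return kept

# %c awaits %b, %b awaits %a: joining all three is the same as waiting on %c.
deps = {"%c": {"%b"}, "%b": {"%a"}}
print(prune_join(["%a", "%b", "%c"], deps))  # ['%c']
```

The 30-operand join shrinking to 2 in the example above is exactly this: most operands were already covered by the chain leading into the two that remain.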
If I understand correctly, this version might include shape assertions because it includes this patch during export. Original Flow Stream VM
Nice, that includes them!
(note that the workload ordinal has a bounded range there now - other dispatches now also have divisibility, e.g. %14<umin = 0, umax = 9007199254740991> : index
-> %10<umin = 16, umax = 131056, udiv = 16>
)
What happened?
I'm getting a register count overflow when trying to run llama3.1_405b_fp16 for 8 HIP devices targeting gfx942
iree/runtime/src/iree/vm/bytecode/verifier.c:345: RESOURCE_EXHAUSTED; register count overflow; loading bytecode module at '405b_f16_tp8.vmfb'; loading modules and dependencies; creating run context
Steps to reproduce your issue
Using llama3.1_405b_f16_tp8.mlir obtained from the SHARK-Platform model export process
Compile the mlir
Attempt to run the vmfb
What component(s) does this issue relate to?
Compiler
Version information
3b751a4d2
Additional context
MLIR vmfb
The sharded parameter files are not necessary to replicate this error, but I can upload these to a private location if necessary.