iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

VM bytecode serialization is slow for sub-byte types #15209

Open ScottTodd opened 12 months ago

ScottTodd commented 12 months ago

Following discussion on Discord here.

I see this code dominating performance profiles: https://github.com/openxla/iree/blob/1fa8b482c1000af56209339355f26f1feb978598/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp#L284-L291

Repro instructions (could be reduced to a smaller program):

  1. Download https://storage.googleapis.com/shark_tank/llama_regression/llama2_7b_int4.mlir
  2. (Optional) convert to MLIR bytecode: python -m iree.compiler.tools.ir_tool copy --emit-bytecode [input.mlir] -o [output.mlirbc]
  3. Compile: iree-compile.exe --iree-hal-target-backends=llvm-cpu --iree-input-type=none --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-stream-resource-max-allocation-size=3221225472 --iree-llvmcpu-target-cpu-features=host llama2_7b_int4.mlirbc -o llama2_7b_int4_cpu.vmfb

For a ~6 minute compile, 4m40s is spent serializing. [profile screenshots]

This shows up in traces with sampling. Instrumentation shows a large empty region after most of the zones: [instrumentation screenshot]

ScottTodd commented 12 months ago

Here is a trace file with what I'm observing from the repro instructions, at commit https://github.com/openxla/iree/commit/193c13202aa28ea94254e8f87764244307c277a1: https://storage.googleapis.com/iree-shared-files/compiler_performance/compile_llama2_7b_int4_cpu_sampling_issue15209.tracy (note: 500MB, careful!)

qedawkins commented 12 months ago

This is related to MLIR's inability to store dense sub-byte resources. In fact, we're doing the reverse at the torch level, thereby temporarily multiplying the total amount of storage required for the resource in the compiler: https://github.com/llvm/torch-mlir/blob/f2c53b8ca5389fc63c38a66892b0d393718c3db4/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp#L107

One way to improve this would be to keep the data in some byte-aligned type and bitcast to an equivalent sub-byte tensor, but this requires changes in the frontend. We would then have to teach consteval to store sub-byte tensors in some packed format and bitcast.
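
As a rough sketch of the idea (the bitcast op named here is just a placeholder, not a confirmed API), the frontend would keep the data byte-aligned and bitcast back to the logical type at use sites:

```mlir
// Sketch only: keep the dense data in a byte-aligned i8 container so
// serialization can copy raw bytes instead of iterating sub-byte APInts.
%packed = arith.constant dense<"0x21436587"> : tensor<4xi8>
// Placeholder bitcast op recovering the logical sub-byte tensor type.
%weights = flow.tensor.bitcast %packed : tensor<4xi8> -> tensor<8xi4>
```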

benvanik commented 12 months ago

yeah, this meets up with the need for packed types ("this is 8xi4 in an i32", so we can treat the atomic storage unit as i32) - critical for things like i3 or i5, but it has performance benefits for aligned types as well (4xi8 in i32, etc). If we know that something is logically i4 but physically i32 with 8 i4s, then we can tell MLIR/etc it's i32s and do the casting ourselves.
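
For example (illustrative IR only, the cast op is not an existing API), a logical tensor<8xi4> could ride in a single i32 storage unit with the compiler doing the casting:

```mlir
// Sketch only: 8 x i4 packed into one i32 atomic storage unit.
%storage = arith.constant dense<0x76543210> : tensor<1xi32>
// Illustrative cast back to the logical sub-byte type.
%logical = flow.tensor.bitcast %storage : tensor<1xi32> -> tensor<8xi4>
```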

qedawkins commented 11 months ago

#15260 should help. We could even add a pass to do this bitcasting ourselves, but I think that's basically just moving serialization earlier, so really we need changes in the frontend for these things.

ScottTodd commented 11 months ago

> #15260 should help. We could even add a pass to do this bitcasting ourselves, but I think that's basically just moving serialization earlier, so really we need changes in the frontend for these things.

Would you expect that PR on its own to improve compile time? I just tested and both baseline and that PR took ~37 minutes in translateModuleToBytecode (after just 1 minute for all the rest of the compilation... ouch). Same command line as in the original issue description. Not sure if 5 minutes -> 37 minutes was a regression or if something else in my environment changed.

[profile screenshot]

qedawkins commented 11 months ago

> #15260 should help. We could even add a pass to do this bitcasting ourselves, but I think that's basically just moving serialization earlier, so really we need changes in the frontend for these things.
>
> Would you expect that PR on its own to improve compile time? I just tested and both baseline and that PR took ~37 minutes in translateModuleToBytecode (after just 1 minute for all the rest of the compilation... ouch). Same command line as in the original issue description. Not sure if 5 minutes -> 37 minutes was a regression or if something else in my environment changed.

Ah sorry, should have clarified. We'll need to refresh the linalg from the frontend as well, because the current linalg already has the weights as i4 (and hence unpacked and stored as APInts), so this PR won't have any effect on its own. We could still add the pass I was talking about (to insert bitcasts around all of the sub-byte constants and pre-serialize) so that this cost is frontloaded.
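
As a sketch, such a pass would rewrite sub-byte constants roughly like this (the packed values and the bitcast op are illustrative; the nibble order depends on the packing convention):

```mlir
// Before: sub-byte constant whose elements serialize one APInt at a time.
%w = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, -8]> : tensor<8xi4>

// After: the data pre-packed into byte-aligned storage (illustrative
// values), plus a bitcast back to the logical i4 tensor type.
%packed = arith.constant dense<"0x21436587"> : tensor<4xi8>
%w_cast = flow.tensor.bitcast %packed : tensor<4xi8> -> tensor<8xi4>
```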

Edit: the 5 minute -> 37 minute regression is pretty spooky though... Not sure what recent changes would have caused that.