iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[Stream] Transient buffer adding extra copies to Llama2 inference #16128

Open Max191 opened 8 months ago

Max191 commented 8 months ago

In the Llama2 model, the concatenation of the growing context is currently getting lowered into a copy to a transient buffer before copying into the global variable. The global_state tensor is originally stored as a single large 64x1x4095x32x128xf32 tensor, which we suspected may be blocking the copies from being elided. However, even with 64 separate 1x1x4095x32x128xf32 tensors, the transient buffer persists.
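As a sanity check on the shapes above, the byte sizes match the %c16384 and %c67092480 constants in the stream IR below. This is a small sketch that assumes dense row-major f32 storage at 4 bytes per element:

```python
# Byte-size arithmetic for the KV-cache globals described above.
# Assumes dense row-major f32 storage (4 bytes per element).
from math import prod

F32_BYTES = 4

def nbytes(shape):
    """Number of bytes in a dense f32 tensor of the given shape."""
    return prod(shape) * F32_BYTES

# One sequence position of one cache entry: 32 heads x 128 features.
row_bytes = nbytes((32, 128))
# One of the 64 split globals: 1x1x4095x32x128xf32.
split_global_bytes = nbytes((1, 1, 4095, 32, 128))
# The original single global: 64x1x4095x32x128xf32.
single_global_bytes = nbytes((64, 1, 4095, 32, 128))

print(row_bytes)            # 16384    (%c16384 in the IR)
print(split_global_bytes)   # 67092480 (%c67092480 in the IR)
print(single_global_bytes)  # 4293918720
```

The per-token transient in the IR, %34, is (seq_step + 1) * row_bytes, so it grows by 16384 bytes every step.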

I have prepared a greatly reduced version of the Llama2 IR that runs into the same issue here: https://gist.github.com/Max191/0bcc276f634daf9341848783aca86cfc

Here is the compile command for the IR in the gist:

iree-compile \
    --iree-opt-const-eval=false \
    --iree-hal-target-backends=llvm-cpu \
    --iree-llvmcpu-target-cpu=znver4 \
    --iree-stream-resource-index-bits=64 \
    --iree-vm-target-index-bits=64 \
    --iree-llvmcpu-enable-ukernels=mmt4d \
    --iree-llvmcpu-link-embedded=false \
    --iree-llvmcpu-debug-symbols=true \
    --iree-global-opt-propagate-transposes \
    --iree-opt-aggressively-propagate-transposes \
    --iree-opt-outer-dim-concat \
    --compile-to=stream \
    --mlir-print-ir-after-all \
    --mlir-disable-threading \
    llama2_int4_reduced.mlir \
    -o llama2_int4_reduced_stream.mlir \
    2> dump_stream.mlir

And here is the IR before EmplaceAllocations for quick reference:

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @run_forward(%arg0: !hal.buffer_view) attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @run_forward(%input0: tensor<1x1xi64>) -> ()"}} {
  %c524288 = arith.constant 524288 : index
  %c524288000 = arith.constant 524288000 : index
  %c2097152 = arith.constant 2097152 : index
  %c1_i32 = arith.constant 1 : i32
  %c268435520_i32 = arith.constant 268435520 : i32
  %c0_i64 = arith.constant 0 : i64
  %c1_i64 = arith.constant 1 : i64
  %c-1_i64 = arith.constant -1 : i64
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4096_i64 = arith.constant 4096 : i64
  %c4097_i64 = arith.constant 4097 : i64
  %c8 = arith.constant 8 : index
  %c16384 = arith.constant 16384 : index
  %c4 = arith.constant 4 : index
  %c8388608 = arith.constant 8388608 : index
  %c67092480 = arith.constant 67092480 : index
  %_global_state.1 = util.global.load @_global_state.1 : !stream.resource<variable>
  %_global_state.0 = util.global.load @_global_state.0 : !stream.resource<variable>
  %_global_seq_step.global = util.global.load @_global_seq_step.global : index
  %_constant = util.global.load @_constant : !stream.resource<constant>
  %_constant_0 = util.global.load @_constant_0 : !stream.resource<constant>
  %_params.model.embed_tokens.weight = util.global.load @_params.model.embed_tokens.weight : !stream.resource<constant>
  %_params.model.layers.0.input_layernorm.weight = util.global.load @_params.model.layers.0.input_layernorm.weight : !stream.resource<constant>
  %model.layers.0.self_attn.k_proj.weight.quant.scale = util.global.load @model.layers.0.self_attn.k_proj.weight.quant.scale : !stream.resource<constant>
  %model.layers.0.self_attn.k_proj.weight.quant.zero_point = util.global.load @model.layers.0.self_attn.k_proj.weight.quant.zero_point : !stream.resource<constant>
  %model.layers.0.self_attn.v_proj.weight.quant.scale = util.global.load @model.layers.0.self_attn.v_proj.weight.quant.scale : !stream.resource<constant>
  %model.layers.0.self_attn.v_proj.weight.quant.zero_point = util.global.load @model.layers.0.self_attn.v_proj.weight.quant.zero_point : !stream.resource<constant>
  %model.layers.0.self_attn.k_proj.weight = util.global.load @model.layers.0.self_attn.k_proj.weight : !stream.resource<constant>
  %model.layers.0.self_attn.v_proj.weight = util.global.load @model.layers.0.self_attn.v_proj.weight : !stream.resource<constant>
  %0 = stream.async.transfer %_global_state.1 : !stream.resource<variable>{%c67092480} -> !stream.resource<*>{%c67092480}
  %1 = stream.async.transfer %_global_state.0 : !stream.resource<variable>{%c67092480} -> !stream.resource<*>{%c67092480}
  %2 = stream.async.transfer %_constant : !stream.resource<constant>{%c2097152} -> !stream.resource<*>{%c2097152}
  %3 = stream.async.transfer %_constant_0 : !stream.resource<constant>{%c2097152} -> !stream.resource<*>{%c2097152}
  %4 = stream.async.transfer %_params.model.embed_tokens.weight : !stream.resource<constant>{%c524288000} -> !stream.resource<*>{%c524288000}
  %5 = stream.async.transfer %_params.model.layers.0.input_layernorm.weight : !stream.resource<constant>{%c16384} -> !stream.resource<*>{%c16384}
  %6 = stream.async.transfer %model.layers.0.self_attn.k_proj.weight.quant.scale : !stream.resource<constant>{%c524288} -> !stream.resource<*>{%c524288}
  %7 = stream.async.transfer %model.layers.0.self_attn.k_proj.weight.quant.zero_point : !stream.resource<constant>{%c524288} -> !stream.resource<*>{%c524288}
  %8 = stream.async.transfer %model.layers.0.self_attn.v_proj.weight.quant.scale : !stream.resource<constant>{%c524288} -> !stream.resource<*>{%c524288}
  %9 = stream.async.transfer %model.layers.0.self_attn.v_proj.weight.quant.zero_point : !stream.resource<constant>{%c524288} -> !stream.resource<*>{%c524288}
  %10 = stream.async.transfer %model.layers.0.self_attn.k_proj.weight : !stream.resource<constant>{%c8388608} -> !stream.resource<*>{%c8388608}
  %11 = stream.async.transfer %model.layers.0.self_attn.v_proj.weight : !stream.resource<constant>{%c8388608} -> !stream.resource<*>{%c8388608}
  hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c1, %c1]) type(%c268435520_i32) encoding(%c1_i32)
  %12 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<1x1xi64> in !stream.resource<external>{%c8}
  %13 = stream.async.transfer %12 : !stream.resource<external>{%c8} -> !stream.resource<*>{%c8}
  %14 = arith.muli %_global_seq_step.global, %c16384 : index
  %15 = arith.index_cast %_global_seq_step.global : index to i64
  %16 = stream.async.dispatch @run_forward_dispatch_0::@run_forward_dispatch_0_generic_4096_i64xf32(%4[%c0 to %c524288000 for %c524288000], %13[%c0 to %c8 for %c8]) : (!stream.resource<*>{%c524288000}, !stream.resource<*>{%c8}) -> !stream.resource<*>{%c16384}
  %17 = stream.async.dispatch @run_forward_dispatch_1::@run_forward_dispatch_1_generic_4096_f32(%16[%c0 to %c16384 for %c16384]) : (!stream.resource<*>{%c16384}) -> !stream.resource<*>{%c4}
  %18 = stream.async.dispatch @run_forward_dispatch_2::@run_forward_dispatch_2_generic_4096_f32(%5[%c0 to %c16384 for %c16384], %16[%c0 to %c16384 for %c16384], %17[%c0 to %c4 for %c4]) : (!stream.resource<*>{%c16384}, !stream.resource<*>{%c16384}, !stream.resource<*>{%c4}) -> !stream.resource<*>{%c16384}
  %19 = stream.async.dispatch @run_forward_dispatch_3::@run_forward_dispatch_3_generic_4096x32x128_f32(%10[%c0 to %c8388608 for %c8388608], %6[%c0 to %c524288 for %c524288], %7[%c0 to %c524288 for %c524288], %18[%c0 to %c16384 for %c16384]) : (!stream.resource<*>{%c8388608}, !stream.resource<*>{%c524288}, !stream.resource<*>{%c524288}, !stream.resource<*>{%c16384}) -> !stream.resource<*>{%c16384}
  %20 = stream.async.dispatch @run_forward_dispatch_3::@run_forward_dispatch_3_generic_4096x32x128_f32(%11[%c0 to %c8388608 for %c8388608], %8[%c0 to %c524288 for %c524288], %9[%c0 to %c524288 for %c524288], %18[%c0 to %c16384 for %c16384]) : (!stream.resource<*>{%c8388608}, !stream.resource<*>{%c524288}, !stream.resource<*>{%c524288}, !stream.resource<*>{%c16384}) -> !stream.resource<*>{%c16384}
  %21 = arith.addi %15, %c1_i64 : i64
  %22 = arith.addi %15, %c4097_i64 : i64
  %23 = arith.cmpi sge, %21, %c0_i64 : i64
  %24 = arith.select %23, %21, %22 : i64
  %25 = arith.cmpi slt, %24, %c0_i64 : i64
  %26 = arith.select %25, %c0_i64, %24 : i64
  %27 = arith.cmpi sgt, %26, %c4096_i64 : i64
  %28 = arith.select %27, %c4096_i64, %26 : i64
  %29 = arith.index_cast %28 : i64 to index
  %30 = arith.cmpi sge, %29, %c0 : index
  %31 = arith.select %30, %29, %c0 : index
  %32 = stream.async.dispatch @run_forward_dispatch_5::@run_forward_dispatch_5_generic_4096_f32[%31, %_global_seq_step.global](%3[%c0 to %c2097152 for %c2097152], %31, %2[%c0 to %c2097152 for %c2097152], %_global_seq_step.global, %19[%c0 to %c16384 for %c16384], %19[%c0 to %c16384 for %c16384]) : (!stream.resource<*>{%c2097152}, index, !stream.resource<*>{%c2097152}, index, !stream.resource<*>{%c16384}, !stream.resource<*>{%c16384}) -> !stream.resource<*>{%c16384}
  %33 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%_global_seq_step.global]
  %34 = arith.muli %33, %c16384 : index
  %35 = stream.async.alloca : !stream.resource<*>{%34}
  %36 = stream.async.copy %1[%c0 to %14], %35[%c0 to %14], %14 : !stream.resource<*>{%c67092480} -> %35 as !stream.resource<*>{%34}
  %37 = arith.addi %14, %c16384 : index
  %38 = stream.async.update %32, %36[%14 to %37] : !stream.resource<*>{%c16384} -> %36 as !stream.resource<*>{%34}
  %39 = stream.async.alloca : !stream.resource<*>{%34}
  %40 = stream.async.copy %0[%c0 to %14], %39[%c0 to %14], %14 : !stream.resource<*>{%c67092480} -> %39 as !stream.resource<*>{%34}
  %41 = stream.async.update %20, %40[%14 to %37] : !stream.resource<*>{%c16384} -> %40 as !stream.resource<*>{%34}
  %42 = arith.index_cast %33 : index to i64
  %43 = arith.addi %42, %c-1_i64 : i64
  %44 = arith.cmpi slt, %43, %c0_i64 : i64
  %45 = arith.select %44, %c0_i64, %43 : i64
  %46 = arith.cmpi sgt, %45, %42 : i64
  %47 = arith.select %46, %42, %45 : i64
  %48 = arith.index_cast %47 : i64 to index
  %49 = arith.muli %48, %c16384 : index
  %50 = arith.addi %49, %c16384 : index
  %51 = stream.async.copy %38[%49 to %50], %1[%14 to %37], %c16384 : !stream.resource<*>{%34} -> %1 as !stream.resource<*>{%c67092480}
  %52 = stream.async.copy %41[%49 to %50], %0[%14 to %37], %c16384 : !stream.resource<*>{%34} -> %0 as !stream.resource<*>{%c67092480}
  %53 = arith.addi %_global_seq_step.global, %c1 : index
  %54 = stream.async.transfer %51 : !stream.resource<*>{%c67092480} -> !stream.resource<variable>{%c67092480}
  %55 = stream.async.transfer %52 : !stream.resource<*>{%c67092480} -> !stream.resource<variable>{%c67092480}
  util.global.store %53, @_global_seq_step.global : index
  util.global.store %54, @_global_state.0 : !stream.resource<variable>
  util.global.store %55, @_global_state.1 : !stream.resource<variable>
  return
}
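A toy model of the byte movement in the IR above, for one of the two globals, may make the redundancy easier to see. It assumes bytes can be modeled as Python lists, uses a hypothetical ROW of 4 elements in place of the 16384-byte row, and assumes seq_step >= 0 so that the clamp chain %42..%48 yields seq_step. In this reduced IR, only the freshly updated row is ever read back out of the transient, so the alloca and the prefix copy (%36) do not affect the stored result:

```python
# Toy model of the stream IR above for one global (illustrative only).
# ROW is a hypothetical stand-in for the 16384-byte row (%c16384).
ROW = 4

def via_transient(global_buf, new_row, seq_step):
    """Mirror %35..%38 and %51: alloca, copy prefix, update, copy back."""
    prefix_end = seq_step * ROW                              # %14
    transient = [None] * ((seq_step + 1) * ROW)              # %35: alloca of %34
    transient[0:prefix_end] = global_buf[0:prefix_end]       # %36: prefix copy
    transient[prefix_end:prefix_end + ROW] = new_row         # %38: update at step
    idx = min(max(seq_step, 0), seq_step + 1)                # clamp chain %42..%48
    src = idx * ROW                                          # %49
    out = list(global_buf)
    out[prefix_end:prefix_end + ROW] = transient[src:src + ROW]  # %51: copy back
    return out

def direct(global_buf, new_row, seq_step):
    """Same result without the transient: write the new row in place."""
    out = list(global_buf)
    out[seq_step * ROW:(seq_step + 1) * ROW] = new_row
    return out

state = list(range(3 * ROW))   # three rows of existing cache contents
new = [-1, -2, -3, -4]
for step in range(3):
    assert via_transient(state, new, step) == direct(state, new, step)
```

In the full model, the concatenated transient also feeds attention, so this equivalence only shows what the stored global ends up containing.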
benvanik commented 8 months ago

nice repro! this is what I was expecting, but should be possible to fix. I'll take a look.

benvanik commented 8 months ago

May actually be tricky (why it hasn't been solved yet!), but I'd like to see a trace of what you're running with this. The copies now should all be concurrent and take up much less wall-time than before (still bad, but should be 1/64th as bad). It's never going to be possible in a full model using dynamic shapes to ensure there are no dynamically shaped transients that need an allocation, so if the time spent is mostly in the allocation with this change (instead of the copies) solving the placement won't really help.

Specifically, this chain is near impossible to eliminate once formed:

    %35 = flow.tensor.empty : tensor<1x32x?x128xf32>{%34}
    %36 = flow.dispatch @run_forward_dispatch_8::@run_forward_dispatch_8_slow_memcpy[%_global_seq_step.global, %34](%6, %35, %_global_seq_step.global, %34) : (tensor<32x?x128xf32>{%_global_seq_step.global}, tensor<1x32x?x128xf32>{%34}, index, index) -> %35{%34}
    %37 = flow.dispatch @run_forward_dispatch_9::@run_forward_dispatch_9_slow_memcpy[%_global_seq_step.global, %34](%33, %36, %_global_seq_step.global, %34) : (tensor<32x128xf32>, tensor<1x32x?x128xf32>{%34}, index, index) -> %36{%34}
    %38 = flow.dispatch @run_forward_dispatch_10::@run_forward_dispatch_10_slow_memcpy[%_global_seq_step.global, %34](%8, %35, %_global_seq_step.global, %34) : (tensor<32x?x128xf32>{%_global_seq_step.global}, tensor<1x32x?x128xf32>{%34}, index, index) -> %35{%34}
    %39 = flow.dispatch @run_forward_dispatch_11::@run_forward_dispatch_11_slow_memcpy[%_global_seq_step.global, %34](%19, %38, %_global_seq_step.global, %34) : (tensor<32x128xf32>, tensor<1x32x?x128xf32>{%34}, index, index) -> %38{%34}
    %40 = flow.dispatch @run_forward_dispatch_12::@run_forward_dispatch_12_slow_memcpy[%_global_seq_step.global, %34](%_global_seq_step.global, %37, %34) : (index, tensor<1x32x?x128xf32>{%34}, index) -> tensor<32x128xf32>
    %41 = flow.dispatch @run_forward_dispatch_13::@run_forward_dispatch_13_slow_memcpy[%_global_seq_step.global, %34](%_global_seq_step.global, %39, %34) : (index, tensor<1x32x?x128xf32>{%34}, index) -> tensor<32x128xf32>
    %42 = flow.tensor.reshape %40 : tensor<32x128xf32> -> tensor<1x1x1x32x128xf32>
    %43 = flow.tensor.update %42, %_global_state.0[%c0, %c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.0 as tensor<1x1x4095x32x128xf32>
    %44 = flow.tensor.reshape %41 : tensor<32x128xf32> -> tensor<1x1x1x32x128xf32>
    %45 = flow.tensor.update %44, %_global_state.1[%c0, %c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.1 as tensor<1x1x4095x32x128xf32>

so long as those slow in-place memcpy operations are there (from tensor insert/extract slices it seems) you'll have dynamic allocations. We can make dispatch_12 and dispatch_13 go in-place into the variables but all the others need to not be dispatches for us to do anything.

benvanik commented 8 months ago

(IOW: I suspect your wall-time with this change is mostly in the alloc/free of the transient memory. So long as any dispatch anywhere in the program produces a result based on a dynamic size, you'll have an alloc/free that won't hit the caching allocator, and if you need things faster then we may need to implement the block-based suballocator for now instead of trying to squeeze more out of placement.)
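To illustrate the caching point, here is a toy simulation. It is not IREE's allocator. It compares a hypothetical cache that reuses a freed buffer only on an exact size match against a hypothetical scheme that rounds requests up to fixed 1 MiB blocks. Because the transient size (%34 = (seq_step + 1) * 16384 bytes) changes every token, the exact-size cache never hits:

```python
# Toy illustration (not IREE's allocator) of why a transient whose size
# depends on seq_step defeats a cache keyed on the exact allocation size.
ROW_BYTES = 16384  # %c16384: one sequence position of one cache entry

def transient_size(seq_step):
    """Size of the per-step transient, %34 in the stream IR above."""
    return (seq_step + 1) * ROW_BYTES

def count_hits(sizes, bucket):
    """Simulate one alloc/free per step with a free-list keyed by bucket(size)."""
    free = set()
    hits = 0
    for size in sizes:
        key = bucket(size)
        if key in free:
            hits += 1
        free.add(key)  # freed at the end of the step, reusable next step
    return hits

sizes = [transient_size(s) for s in range(100)]
exact = count_hits(sizes, lambda n: n)

# Hypothetical block-based scheme: round each request up to whole 1 MiB blocks.
BLOCK = 1 << 20
blocked = count_hits(sizes, lambda n: -(-n // BLOCK) * BLOCK)  # ceil to block
```

With 100 steps the exact-size cache reuses nothing, while the rounded variant misses only twice (once per distinct block count).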

benvanik commented 8 months ago

The slow memcpy dispatches look like they come from a tensor.concat + tensor.extract_slice - I can't quite tell what it's doing, but it seems like the copy is required? Definitely want someone to look at this higher up and ensure we aren't mixing a bunch of data together and creating these dependencies when it's not required. In the fullness of time we can make several of these things better but so long as we end up with dispatches acting opaquely on blobs of data - especially in-place ones - we can't do anything safely.

  %_global_state.0 = util.global.load @_global_state.0 : tensor<1x1x4095x32x128xf32>
  %1 = flow.tensor.slice %_global_state.0[%c0, %c0, %c0, %c0, %c0 for %c1, %c1, %_global_seq_step.global, %c32, %c128] : tensor<1x1x4095x32x128xf32> -> tensor<1x1x?x32x128xf32>{%_global_seq_step.global}
  %2 = flow.tensor.reshape %1 : tensor<1x1x?x32x128xf32>{%_global_seq_step.global} -> tensor<1x?x32x128xf32>{%_global_seq_step.global}
  %5 = tensor.empty(%_global_seq_step.global) : tensor<1x32x?x128xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1x?x32x128xf32>) outs(%5 : tensor<1x32x?x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<1x32x?x128xf32>
  ...
  %54 = linalg.generic ... -> tensor<1x32x1x128xf32>
  // ---------------------
  // <<WHAT IS THIS??>>
  // this will always require dynamic allocations unless we can fold the concat+extract_slice away
  %concat = tensor.concat dim(2) %6, %54 : (tensor<1x32x?x128xf32>, tensor<1x32x1x128xf32>) -> tensor<1x32x?x128xf32>
  %extracted_slice_19 = tensor.extract_slice %concat[0, 0, %63, 0] [1, 32, 1, 128] [1, 1, 1, 1] : tensor<1x32x?x128xf32> to tensor<1x32x1x128xf32>
  %64 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_19 : tensor<1x32x1x128xf32>) outs(%57 : tensor<1x1x32x128xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<1x1x32x128xf32>
  // ---------------------
  %66 = flow.tensor.reshape %64 : tensor<1x1x32x128xf32> -> tensor<1x1x1x32x128xf32>
  %67 = flow.tensor.update %66, %_global_state.0_22[%c0, %c0, %_global_seq_step.global_21, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.0_22 as tensor<1x1x4095x32x128xf32>
  util.global.store %67, @_global_state.0 : tensor<1x1x4095x32x128xf32>
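The concat + extract_slice pair can only fold away if the slice offset %63 is known to equal the dynamic size of %6 (%_global_seq_step.global). The clamp arithmetic analyzed later in this thread appears to reduce to exactly that for non-negative steps. A sketch of the fold condition, modeling rows along dim 2 as list elements (function names are hypothetical):

```python
# Sketch of when tensor.concat + tensor.extract_slice folds away.
# Rows along dim 2 are modeled as list elements; names are illustrative.

def concat_then_extract(prefix_rows, new_row, offset):
    """%concat = concat(%6, %54) along dim 2, then one row at `offset`."""
    concat = prefix_rows + [new_row]
    return concat[offset]

def folded(prefix_rows, new_row, offset):
    """Folded form: valid only when the offset is the prefix length."""
    if offset != len(prefix_rows):
        raise ValueError("fold requires offset == dynamic size of the prefix")
    return new_row
```

If the offset cannot be proven equal to the prefix length, the slice may read from %6 instead, and the concat (with its dynamically sized allocation) has to stay.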
qedawkins commented 8 months ago

> The slow memcpy dispatches look like they come from a tensor.concat + tensor.extract_slice - I can't quite tell what it's doing, but it seems like the copy is required?

Are you compiling with --iree-global-opt-propagate-transposes and --iree-opt-outer-dim-concat? Those should fold the concat and extract_slice. Otherwise, IIRC there are some PyTorch-related reasons we end up generating IR like that.

benvanik commented 8 months ago

not using flags besides const-eval disable. I'm not liking that such flags exist - any reason those aren't on by default? (and also named differently? shouldn't they all be global-opt?)

qedawkins commented 8 months ago

> not using flags besides const-eval disable. I'm not liking that such flags exist - any reason those aren't on by default? (and also named differently? shouldn't they all be global-opt?)

I thought I had renamed --iree-opt-outer-dim-concat, actually... I can do that in a bit, but to summarize:

--iree-opt-outer-dim-concat: I needed to land the change to the frontend that created concats after adding support in IREE, and I added this flag at the same time, but to avoid performance regressions on a torch-mlir integrate I left it off by default. We can try turning it on by default now I think, it was more of a phase ordering thing.

--iree-global-opt-propagate-transposes: There were some regressions, but fixes should be landing soon (tracking issue #15973)

(#16059 is to turn the latter back on, needs some other fixes it looks like)

benvanik commented 8 months ago

Cool - any flag we can remove makes things much better!

I ran with that and the copies at the start go away but the ones later on still exist. It's still not clear that we can reliably eliminate them as it's mixing data:

    %27 = flow.dispatch @run_forward_dispatch_5::@run_forward_dispatch_5_generic_4096_f32[%25, %_global_seq_step.global](%_constant_0, %25, %_constant, %_global_seq_step.global, %26, %12) : (tensor<4096x128xf32>, index, tensor<4096x128xf32>, index, tensor<1x32x1x2x64xf32>, tensor<4096xf32>) -> tensor<4096xf32>
    %28 = affine.apply #map7()[%_global_seq_step.global]
    %29 = flow.tensor.empty : tensor<1x?x32x128xf32>{%28}
    %30 = flow.tensor.update %2, %29[%c0, %c0, %c0, %c0] : tensor<1x?x32x128xf32>{%_global_seq_step.global} -> %29 as tensor<1x?x32x128xf32>{%28}
    %31 = flow.tensor.reshape %27 : tensor<4096xf32> -> tensor<1x1x32x128xf32>
    %32 = flow.tensor.update %31, %30[%c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x32x128xf32> -> %30 as tensor<1x?x32x128xf32>{%28}
    %33 = flow.tensor.reshape %32 : tensor<1x?x32x128xf32>{%28} -> tensor<?x1x32x128xf32>{%28}
    %45 = flow.tensor.slice %33[%44, %c0, %c0, %c0 for %c1, %c1, %c32, %c128] : tensor<?x1x32x128xf32>{%28} -> tensor<1x1x32x128xf32>
    %47 = flow.tensor.reshape %45 : tensor<1x1x32x128xf32> -> tensor<1x1x1x32x128xf32>
    %48 = flow.tensor.update %47, %_global_state.0[%c0, %c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.0 as tensor<1x1x4095x32x128xf32>
benvanik commented 8 months ago

(I'm looking for a non-inplace flow.dispatch -> flow.tensor.update of the global OR a single flow.dispatch with in-place into the global, anything else in-between makes it much harder)

qedawkins commented 8 months ago

> not using flags besides const-eval disable. I'm not liking that such flags exist - any reason those aren't on by default? (and also named differently? shouldn't they all be global-opt?)

OK, actually the reason they have different names is that the naming scheme we chose for compiler flags differs from the one for internal flags. For example, here are the GlobalOptimization compiler flags: https://github.com/openxla/iree/blob/dd176120bf544646f7c436c3bcb30d68aea721af/compiler/src/iree/compiler/Pipelines/Options.cpp#L116

And here are the internal GlobalOptimization flags: https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

We could do a mass flag renaming for the compiler ones. For context on why one is a compiler flag and one is internal, I didn't actually want --iree-global-opt-propagate-transposes to ever be a flag because it should always be worth having on. --iree-opt-outer-dim-concat is less clear though, as there are cases where we wouldn't want to do this so I exposed the flag all the way. Probably needs more workloads to solidify.

Edit: sorry, getting a bit off-topic. Just some background on the flags. Can move discussion elsewhere if it's worth continuing.

benvanik commented 8 months ago

Mostly just has a big smell to it - any flag that is required to make a model perform well is effectively magic, and every use of the flag is a bug on our side as it means the user had to be aware of the flag, do the deep technical work to know they need it, and actually apply it, and then communicate that it's load-bearing in any comms they have with us. As a general rule we try never to add flags that survive past experimentation/integrations/etc for those reasons. If it's critical enough to exist and will exist forever then it should be an attribute on the IR instead.

qedawkins commented 8 months ago

> Mostly just has a big smell to it - any flag that is required to make a model perform well is effectively magic, and every use of the flag is a bug on our side.

+1, and I'd like to make the above flag set the default, just takes some work to get there. Also it would be good to do a refactoring of all the compiler flags (i.e. ones currently exposed with GlobalOptimizationOptions/StreamOptions/FlowOptions/whatever) to attributes, but someone needs to find time for it.

qedawkins commented 8 months ago

> Cool - any flag we can remove makes things much better!
>
> I ran with that and the copies at the start go away but the ones later on still exist. It's still not clear that we can reliably eliminate them as it's mixing data.

In response to this: it looks quite similar to how the global updates are constructed here: https://github.com/nod-ai/SHARK-Turbine/blob/main/examples/llama2_inference/llama2.ipynb (not sure how to link to specific lines in an ipynb)

def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
    """
    global_pkv: the global pkv tensor
    seq_step: the current token index of the model
    heads: the number of attn heads
    hidden_dim: feature dimension size
    takes the global_pkv tensor and gets the seq_step pair for each head
    """
    all_pkv_tensors = []
    for i in range(heads * 2):
        sliced = IREE.tensor_slice(
            global_pkv, i, 0, (0, heads), (0, seq_step), (0, hidden_dim)
        )  # sequence context dim
        all_pkv_tensors.append(
            IREE.tensor_reshape(sliced, 1, heads, seq_step, hidden_dim)
        )

    return all_pkv_tensors

and


def update_state(global_pkv, state_updates, seq_step, heads, hidden_dim):
    """
    global_pkv: the global pkv tensor
    state_updates: the state updates output by a forward pass of the model
    seq_step: the current token index of the model
    heads: the number of attn heads
    hidden_dim: feature dimension size
    updates the global state of the model at seq_step with state_updates
    """
    all_updates = []
    for i in range(heads * 2):
        # expand dim in state updates to match the rank of global_pkv
        update = IREE.tensor_reshape(
            state_updates[i], 1, 1, heads, 1, hidden_dim
        )
        all_updates.append(
            IREE.tensor_update(global_pkv, update, i, 0, 0, seq_step, 0)
        )
    return all_updates

I was expecting splitting the context into 64 globals to simplify the logic here but looks like that's maybe not the case.

benvanik commented 8 months ago

if it's something we can improve propagation for (at aten/torch/whatever or linalg/tensor) then things will be fine - or new canonicalization patterns/folders on the flow.tensor.* ops would be awesome. I am not sure of what it's doing though so not confident we can fold them away. would definitely be worthwhile if we could.

qedawkins commented 8 months ago

It looks like we could add a folding pattern for flow.tensor.reshape -> flow.tensor.update pretty easily:

%73 = flow.tensor.reshape %65 : tensor<1x1x32x128xf32> -> tensor<1x1x1x32x128xf32>
%74 = flow.tensor.update %73, %_global_state.0_23[%c0, %c0, %_global_seq_step.global_22, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.0_23 as tensor<1x1x4095x32x128xf32>

That should help get rid of some of the noise.

The same goes for tensor.reshape-like ops next to flow ops, for when flow ops are part of the model input.
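The reshape -> update fold is safe because the leading unit dims that flow.tensor.reshape adds do not change which bytes the update touches. The sketch below uses hypothetical helpers and assumes dense row-major f32 storage. It computes the covered byte range for both the 5-D global view and a 4-D view with the leading unit dim collapsed. For seq_step = s, both give [s * 16384, (s + 1) * 16384), which matches the %14 to %37 range in the stream IR at the top of the issue:

```python
# Sketch: bytes touched by an update into a dense row-major target.
# Leading unit dims added by a reshape do not change the range, which is
# why folding reshape -> update is safe. Assumes f32 (4 bytes).
from math import prod

def strides(shape):
    """Row-major element strides for a dense tensor."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return list(reversed(out))

def is_contiguous_update(target_shape, update_shape):
    """True if the update region is a single contiguous row-major run."""
    dims = list(zip(target_shape, update_shape))
    i = 0
    while i < len(dims) and dims[i][1] == 1:  # skip leading unit dims
        i += 1
    # The first non-unit dim may be partial; every dim after it must be full.
    return all(t == u for t, u in dims[i + 1:])

def update_byte_range(target_shape, offsets, update_shape, elem_bytes=4):
    """[start, end) byte range covered by a contiguous update."""
    if not is_contiguous_update(target_shape, update_shape):
        raise ValueError("update region is not contiguous")
    start = sum(o * s for o, s in zip(offsets, strides(target_shape)))
    return start * elem_bytes, (start + prod(update_shape)) * elem_bytes

s = 5
print(update_byte_range((1, 1, 4095, 32, 128), (0, 0, s, 0, 0), (1, 1, 1, 32, 128)))
print(update_byte_range((1, 4095, 32, 128), (0, s, 0, 0), (1, 1, 32, 128)))
```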

qedawkins commented 8 months ago

A few notes:


There are a lot of repeated util.global.load ops, e.g.

%_global_seq_step.global_33 = util.global.load @_global_seq_step.global : index
%_global_seq_step.global_35 = util.global.load @_global_seq_step.global : index
%_global_seq_step.global_37 = util.global.load @_global_seq_step.global : index

We could probably run SimplifyGlobalAccesses around the Flow level to simplify some of these (might not matter but cleanup is always good). Needs #16002 for it to be correct in conjunction with SCF though.


This dynamic index arithmetic looks like it might be blocking some potential simplification/canonicalization

  %75 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%_global_seq_step.global]
  %76 = arith.index_cast %75 : index to i64
  %77 = arith.addi %76, %c-1_i64 : i64
  %78 = arith.cmpi slt, %77, %c0_i64 : i64
  %79 = arith.select %78, %c0_i64, %77 : i64
  %80 = arith.cmpi sgt, %79, %76 : i64
  %81 = arith.select %80, %76, %79 : i64
  %82 = arith.index_cast %81 : i64 to index

This is just

int64_t step = step_index + 1;
int64_t sm1 = step - 1;
int64_t clamp_0 = sm1 < 0 ? 0 : sm1;
int64_t clamp_high = clamp_0 > step ? step : clamp_0;

Which simplifies to (unless I made a mistake)

int64_t clamp_high = step_index < 0 ? step_index + 1 : step_index;

and if we further assert step_index >= 0 (which it is for our model), this all goes away and just becomes step_index. I'd guess there is a torch lowering not obeying torch.assume_symbolic_shapes or something that is causing this (also still not really sure what it's supposed to mean anyway).
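A quick brute-force check of the simplification. The first function transcribes the arith ops %75..%81 literally (note that %80 compares clamp_0 against step). The second is the simplified form. Python ints do not overflow, so this ignores i64 wraparound at the extremes:

```python
# Brute-force check of the clamp simplification (ignores i64 overflow).

def original(step_index):
    step = step_index + 1                              # %75 / %76
    sm1 = step - 1                                     # %77
    clamp_0 = 0 if sm1 < 0 else sm1                    # %78, %79
    clamp_high = step if clamp_0 > step else clamp_0   # %80, %81
    return clamp_high

def simplified(step_index):
    return step_index + 1 if step_index < 0 else step_index

assert all(original(i) == simplified(i) for i in range(-1000, 1000))
# For non-negative steps the whole chain is the identity.
assert all(original(i) == i for i in range(0, 1000))
```

So any range analysis that knows step_index >= 0 could reduce the chain to step_index itself.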

benvanik commented 8 months ago

nice! simplifying the math would definitely help, as that kind of sequence throws off any kind of range analysis we'd do to know that updates were non-overlapping

qedawkins commented 8 months ago

(on a related note, it helps to reduce on/share Torch IR directly for stuff like this. It's hard to know where this might be coming from and where to make a change starting from Linalg IR; @Max191 if possible can you share the Torch IR when you get a chance?)

Max191 commented 8 months ago

This may actually be tricky (which is why it hasn't been solved yet!), but I'd like to see a trace of what you're running with this. The copies should now all be concurrent and take up much less wall time than before (still bad, but should be 1/64th as bad). It's never going to be possible in a full model using dynamic shapes to ensure there are no dynamically shaped transients that need an allocation, so if the time spent with this change is mostly in the allocation (instead of the copies), solving the placement won't really help.

Specifically, this chain is nearly impossible to eliminate once formed:

    %35 = flow.tensor.empty : tensor<1x32x?x128xf32>{%34}
    %36 = flow.dispatch @run_forward_dispatch_8::@run_forward_dispatch_8_slow_memcpy[%_global_seq_step.global, %34](%6, %35, %_global_seq_step.global, %34) : (tensor<32x?x128xf32>{%_global_seq_step.global}, tensor<1x32x?x128xf32>{%34}, index, index) -> %35{%34}
    %37 = flow.dispatch @run_forward_dispatch_9::@run_forward_dispatch_9_slow_memcpy[%_global_seq_step.global, %34](%33, %36, %_global_seq_step.global, %34) : (tensor<32x128xf32>, tensor<1x32x?x128xf32>{%34}, index, index) -> %36{%34}
    %38 = flow.dispatch @run_forward_dispatch_10::@run_forward_dispatch_10_slow_memcpy[%_global_seq_step.global, %34](%8, %35, %_global_seq_step.global, %34) : (tensor<32x?x128xf32>{%_global_seq_step.global}, tensor<1x32x?x128xf32>{%34}, index, index) -> %35{%34}
    %39 = flow.dispatch @run_forward_dispatch_11::@run_forward_dispatch_11_slow_memcpy[%_global_seq_step.global, %34](%19, %38, %_global_seq_step.global, %34) : (tensor<32x128xf32>, tensor<1x32x?x128xf32>{%34}, index, index) -> %38{%34}
    %40 = flow.dispatch @run_forward_dispatch_12::@run_forward_dispatch_12_slow_memcpy[%_global_seq_step.global, %34](%_global_seq_step.global, %37, %34) : (index, tensor<1x32x?x128xf32>{%34}, index) -> tensor<32x128xf32>
    %41 = flow.dispatch @run_forward_dispatch_13::@run_forward_dispatch_13_slow_memcpy[%_global_seq_step.global, %34](%_global_seq_step.global, %39, %34) : (index, tensor<1x32x?x128xf32>{%34}, index) -> tensor<32x128xf32>
    %42 = flow.tensor.reshape %40 : tensor<32x128xf32> -> tensor<1x1x1x32x128xf32>
    %43 = flow.tensor.update %42, %_global_state.0[%c0, %c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.0 as tensor<1x1x4095x32x128xf32>
    %44 = flow.tensor.reshape %41 : tensor<32x128xf32> -> tensor<1x1x1x32x128xf32>
    %45 = flow.tensor.update %44, %_global_state.1[%c0, %c0, %_global_seq_step.global, %c0, %c0] : tensor<1x1x1x32x128xf32> -> %_global_state.1 as tensor<1x1x4095x32x128xf32>

So long as those slow in-place memcpy operations are there (from tensor insert/extract slices, it seems), you'll have dynamic allocations. We can make dispatch_12 and dispatch_13 go in-place into the variables, but all the others need to not be dispatches for us to do anything.

The profile is almost unchanged actually. The copies that have now become concurrent (transient buffer -> global) are only the static size of the last context layer. The heavy cost comes from copying the entire context into the transient buffer, which is unchanged by splitting the context into 64 globals.
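As a rough back-of-the-envelope check, the Python sketch below compares per-step copy volume. It assumes each step copies the whole current prefix of all 64 1x1x4095x32x128xf32 slices into the transient, while the write-back only covers one new 32x128xf32 position per slice. The constants come from the shapes in this issue; the cost model itself is a simplification.

```python
F32_BYTES = 4
HEADS, HEAD_DIM = 32, 128  # trailing dims of each 1x1x4095x32x128xf32 slice
NUM_SLICES = 64            # the 64 globals in the split variant
MAX_LEN = 4095


def bytes_per_position() -> int:
    """Bytes in one 32x128xf32 sequence position."""
    return HEADS * HEAD_DIM * F32_BYTES


def writeback_bytes_per_step() -> int:
    """Copying the one new position back into every global."""
    return NUM_SLICES * bytes_per_position()


def transient_bytes_per_step(seq_len: int) -> int:
    """Copying the full current prefix of every slice into the transient."""
    return NUM_SLICES * seq_len * bytes_per_position()


MiB = 2**20
print(writeback_bytes_per_step() / MiB)         # 1.0
print(transient_bytes_per_step(MAX_LEN) / MiB)  # 4095.0
```

Under these assumptions the transient copy grows linearly with context length (about seq_len MiB per step), while the write-back stays at about 1 MiB per step, so making only the write-back concurrent would barely move the profile.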

I believe the transient buffer comes from the concat of the context, which is used later in a batch_matmul (removed in this repro). This concat comes from inside the pretrained model, not from anything we are doing in turbine as far as I can tell. Unless we want to make changes to the model, we can't do much at the Python level.

So really, the purpose of this buffer is to feed the batch_matmul consumer. If we want to get rid of it, we will need to have the batch_matmul read directly from the global_state after the global is updated. This is also tricky to do at the Python level, because the globals are updated after the forward call of the pretrained model, and this would again require us to break open the AutoModelForCausalLM.from_pretrained model.

This seems pretty tricky to fix in the compiler, but let me get some IR to show what I'm talking about more clearly.

qedawkins commented 8 months ago

The profile is almost unchanged actually.

(is the profile small enough to share?)

stellaraccident commented 8 months ago

Changing the model is completely in bounds. Indeed, needing it to be consumed by the bmm is what forces layout decisions and creates a copy bubble. Look at the math and see if there is a better way, and we can fix the model.

If you're having trouble tracking the hf model, this one might be more educational and easier to visualize: https://github.com/stellaraccident/llama.turbine/blob/4c151a7bc4d6e8dc550aa93e73f29f10283b2b40/python/turbine_llamacpp/model.py#L227

Max191 commented 8 months ago

Profile with original (single-tensor) context
Profile with split (64-tensor) context
EDIT: updated the split trace with the correct link.

Max191 commented 8 months ago

That's good to know. I think these changes are very doable at the Python level as long as we can change the model, so this is where we should start.

stellaraccident commented 8 months ago

I expect that at the level of normal ML ops, there is an extra transpose in there that is throwing it all off: you really want to be feeding into a matmul with a transposed rhs. Just guessing, but I think that is causing some obtuse data layout changes and copies in the source.
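One way to read the transposed-rhs suggestion: the attention scores for a head are q @ k^T, and if the cache is stored sequence-major (one head_dim-long row per position), a matmul with a transposed rhs can read those rows directly instead of first materializing a [head_dim, seq] transposed copy. The plain-Python sketch below (single head, toy sizes, illustrative names) checks that both formulations give the same scores.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """a @ b with b laid out as [K][N]."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def matmul_transposed_rhs(a: Matrix, b_t: Matrix) -> Matrix:
    """a @ b_t^T with b_t laid out as [N][K]: each output element is a dot
    product of a row of `a` with a row of `b_t`, so no transpose is needed."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b_t]
            for row_a in a]


# One head: the query is [1][D], the cached keys are [S][D] (sequence-major).
q = [[1.0, 2.0, 3.0]]
k_cache = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [2.0, 0.0, 1.0]]

# bmm-style: materialize a [D][S] transposed copy of the cache first.
scores_via_copy = matmul(q, transpose(k_cache))
# Transposed-rhs style: consume the cache rows as they are stored.
scores_direct = matmul_transposed_rhs(q, k_cache)

assert scores_via_copy == scores_direct == [[1.0, 2.0, 6.0, 5.0]]
```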

benvanik commented 8 months ago

something feels real bad with the split trace - are these even the same model? O_O 110ms -> 500ms, nearly all time now in generic ops

if you add --iree-hal-dump-executable-files-to=path/ to your compiler command line and capture again, we can see the executable sources in the traces. I suspect that adding flow.tensor.slice ops instead of linalg/tensor ops in the frontend is breaking all kinds of fusion (as linalg can't understand the flow ops)

non-split, all matmul, some bf16 stuff: [image]

split, all generic stuff, all f32: [image]

copy-wise, there's significantly less time in split than the original (82ms CPU time -> 15ms CPU time) and significantly less peak memory consumption (11.7GB -> 9.2GB), so it's definitely a win. I think it's causing issues with the rest of the compilation flow.

Max191 commented 8 months ago

Uh oh, I uploaded the wrong trace :p Let me fix that

benvanik commented 8 months ago

it's useful to name your traces something meaningful :)

Max191 commented 8 months ago

Indeed. The link above is updated with the correct trace.

benvanik commented 8 months ago

which one? they're still named the same :P (it's useful for people not in your head to have clearly named things, dates/versions, etc - I don't know what I'm looking at)

Max191 commented 8 months ago

Original trace was correct before. The split trace has an updated link.

benvanik commented 8 months ago

I think someone broke source code attachment - sigh - can you try using an absolute path to iree-hal-dump-executable-files-to=?

I think all the cost in both of these is coming from a slice + pack - the slice needs to be fused with the pack, and ideally that pack needs to be propagated back across the globals such that we aren't packing/unpacking each step
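To illustrate what fusing the slice with the pack buys, here is a small Python model of a 2-D tensor.pack-style relayout (inner tiles on both dims, out-of-bounds slots filled with a padding value). The unfused path materializes the slice and then packs it; the fused path has the pack read straight from the source at an offset, so no intermediate slice buffer exists. Shapes, tile sizes, and helper names are all made up for the example.

```python
import math
from typing import Callable, List

Tile = List[List[float]]


def pack_2d(get: Callable[[int, int], float], rows: int, cols: int,
            tr: int, tc: int, pad: float = 0.0) -> List[List[Tile]]:
    """Relayout a rows x cols view into ceil(rows/tr) x ceil(cols/tc) tiles of
    shape tr x tc, filling out-of-bounds slots with `pad` (roughly what
    tensor.pack does with inner_dims_pos = [0, 1] and a padding value)."""
    packed = []
    for bi in range(math.ceil(rows / tr)):
        tile_row = []
        for bj in range(math.ceil(cols / tc)):
            tile = [[get(bi * tr + i, bj * tc + j)
                     if bi * tr + i < rows and bj * tc + j < cols else pad
                     for j in range(tc)]
                    for i in range(tr)]
            tile_row.append(tile)
        packed.append(tile_row)
    return packed


# Source: an 8 x 6 "global"; we want rows [off, off + n) of it, packed 4 x 4.
src = [[float(r * 10 + c) for c in range(6)] for r in range(8)]
off, n = 2, 5

# Unfused: materialize the slice into a new buffer, then pack that buffer.
sliced = [list(row) for row in src[off:off + n]]
unfused = pack_2d(lambda i, j: sliced[i][j], n, 6, 4, 4)

# Fused: the pack reads straight from the source at an offset; no slice buffer.
fused = pack_2d(lambda i, j: src[off + i][j], n, 6, 4, 4)

assert fused == unfused
```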

I think we're a ways off from needing stream improvements here - I think these are higher up in the stack: we shouldn't be packing/unpacking, and we must fuse slices with packing, which won't happen if the frontend is adding flow slice ops

benvanik commented 8 months ago

the fixes we could make in stream will save us a bit of time; in orig we've got 400us of wall time spent copying back into variables: [image]

while in split we've got 100us because we can run the copies concurrently: [image]

but splitting only helps with something like <0.1% of total wall time; the bigger wins are things like propagating packs/slices/etc., but all at a higher level

Max191 commented 8 months ago

Yes this was my conclusion as well. We had a discussion about pack propagation across the globals on Tuesday, and concluded we should avoid having to do that if possible, but I think it is necessary now. With the current model, we would need to propagate the pack across both a concat and the global, which is tricky because one of the packs is packing the dynamic dim along which the concat is happening. However, I think if we can change the model in python so that we just update the static max_length context tensor and then load from the same tensor, then we should only have to propagate the pack across a flow.tensor.update which seems more manageable. CC @MaheshRavishankar @bjacob @hanhanW from the discussion on Tuesday.

benvanik commented 8 months ago

yeah, propagating the pack so that we can elide the load->pack->unpack->store sequence is critical, and then for cases where that can't work ensuring we never perform an unfused slice->pack or unpack->update is also critical - no more kicking that can down the road :)

qedawkins commented 8 months ago

Maybe you're already factoring this in, but I just want to add that before jumping straight to propagation, it would be really helpful to hand write the IR that you think we should be generating (i.e. packing the global by hand) to make sure it all connects. Propagation in this case gets really hard because packing of dynamic dimensions implies padding, and pack/unpack propagation with padding is significantly more challenging than aligned propagation.

Max191 commented 8 months ago

Sure, some IR would probably be useful regardless. I think we can take advantage of the fact that these globals are uninitialized. Maybe when we propagate the pack we can create an initializer function that just initializes the global to the padding value, and then we don't have to worry about the padding when we load it.
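A toy Python sketch of that idea: keep the global permanently in a packed layout (here a 1-D sequence split into tiles of a hypothetical size 4), initialize every slot to the padding value, and write each new step directly to its packed coordinates. The assertions check that the tiles in use always equal a fresh pack of the logical prefix, so a load would not need to re-pad. This only exercises the layout bookkeeping; whether the padding value is actually safe for the consumer is a separate question.

```python
import math
from typing import List

TILE = 4      # hypothetical inner tile size along the sequence dimension
PAD = 0.0     # hypothetical padding value the global is initialized to
MAX_LEN = 10  # toy stand-in for the 4095-entry context


def pack_rows(seq: List[float], tile: int, pad: float) -> List[List[float]]:
    """Pack a 1-D sequence into ceil(len/tile) tiles, padding the last one."""
    n = len(seq)
    return [[seq[t * tile + i] if t * tile + i < n else pad for i in range(tile)]
            for t in range(math.ceil(n / tile))]


# The global lives in packed form and starts out as all padding.
packed_global = [[PAD] * TILE for _ in range(math.ceil(MAX_LEN / TILE))]

history: List[float] = []  # the logical, unpacked context, for checking only
for step in range(MAX_LEN):
    value = float(step + 1)  # the new entry produced at this step (never PAD)
    history.append(value)
    # In-place update: write the new entry at its packed coordinates.
    packed_global[step // TILE][step % TILE] = value
    # The tiles in use already equal a fresh pack of the logical prefix,
    # so a load never has to re-pad anything.
    used = math.ceil(len(history) / TILE)
    assert packed_global[:used] == pack_rows(history, TILE, PAD)
```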

MaheshRavishankar commented 8 months ago

+1 to what @qedawkins said above... details of pack propagation in the presence of globals + padding are tricky. Did we try transpose propagation? Let's look at the IR after some data tiling + canonicalization. If all of that looks fine, then we can do pack propagation (and it will probably end up being very specific to this model).

hanhanW commented 8 months ago

Note that pack propagation could introduce new pack ops that do not get fused with other ops, especially when the producer op has more than one input operand. E.g., https://github.com/llvm/llvm-project/blob/7f0515323670459b8376b866bc73a448f0a5aa6e/mlir/test/Dialect/Linalg/data-layout-propagation.mlir#L150-L199

Propagating a pack op with a padding_value is also dangerous, as it might introduce undefined behavior.

Can we try the other ideas first? If they don't work, we can revisit the propagation.

I haven't read all the comments yet. It looks like we should try to fuse copy + pack as well?

benvanik commented 8 months ago

Assume that if it is difficult or specific to this model that we need to improve our mechanisms (which may start as annotations). We control the globals, we control the loads/stores to the globals, and we control the slicing/insertion into the globals, and we control the packing/unpacking. If we can't make all of our own stuff line up then we need to go back to the drawing board. I don't think it's that bad, though from what you're all saying you seem to think it is. We should reconcile that.

qedawkins commented 8 months ago

I want to play both sides a little bit here based on my experience with pack/unpack propagation, and also open up the possibility of an IREE ODM on this because we really need to be better on this front.

+1 to what Hanhan is saying, propagation with padding has a lot of corner cases and if not done carefully can hurt performance fairly significantly. I think there is a path forward, but it will likely require something like a tensor.repack operation to update padding values.

yeah, propagating the pack so that we can elide the load->pack->unpack->store sequence is critical

Big +1 to this. Simple transitive use foldings like this need to have a mechanism for folding away. If the pack and unpack were right next to each other this would just be a folding pattern. We should have the same thing for things like loop carried values in scf.for as well. The question is whether the case here is this kind of "simple" folding (I don't think it is, but we need to look at end state IR).

Did we try transpose propagation? Let's look at the IR after some data tiling + canonicalization. If all of that looks fine then we can do pack propagation

+1. Let's get some end-state IR to look at; it should help answer the previous two questions. We can take the above linalg as a starting point and just start culling various parts of it not relevant to the core computation (e.g. that chain of index arithmetic above).

hanhanW commented 8 months ago

yeah, propagating the pack so that we can elide the load->pack->unpack->store sequence is critical

Big +1 to this. Simple transitive use foldings like this need to have a mechanism for folding away. If the pack and unpack were right next to each other this would just be a folding pattern.

This is only true when they have the same packing attributes. It mostly happens when propagating an unpack op down, which generates some pack ops as well. If this kind of pack op meets other unpack ops, we are able to cancel them. Otherwise, they become transpose + reshape + transpose, which makes things worse. The reassociation maps of the reshape op would be very tricky in this case.
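The cancellation condition can be pictured as a fold rule. The Python sketch below uses a made-up mini-IR (not MLIR's actual classes): a pack of an unpack folds back to the original value only when both ops agree on inner_dims_pos, inner_tiles and outer_dims_perm. It also assumes the pack must not carry a padding value, since the unpack drops padded elements and a re-pack would refill them; treat that extra condition as an assumption of this sketch rather than a statement about the upstream patterns. When the attributes differ, the pair stays and would have to lower to the transpose + reshape sequence described above.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Value:
    name: str


@dataclass(frozen=True)
class Layout:
    """Packing attributes shared by pack and unpack in this toy IR."""
    inner_dims_pos: Tuple[int, ...]
    inner_tiles: Tuple[int, ...]
    outer_dims_perm: Tuple[int, ...] = ()


@dataclass(frozen=True)
class Unpack:
    source: "Expr"
    layout: Layout


@dataclass(frozen=True)
class Pack:
    source: "Expr"
    layout: Layout
    padding_value: Optional[float] = None


Expr = Union[Value, Pack, Unpack]


def fold_pack_of_unpack(op: Expr) -> Expr:
    """pack(unpack(x)) -> x, but only when the packing attributes match
    and the pack does not introduce a padding value."""
    if (isinstance(op, Pack) and isinstance(op.source, Unpack)
            and op.padding_value is None
            and op.layout == op.source.layout):
        return op.source.source
    return op


x = Value("packed_global")
same = Pack(Unpack(x, Layout((0,), (16,))), Layout((0,), (16,)))
different = Pack(Unpack(x, Layout((0,), (16,))), Layout((0,), (8,)))
assert fold_pack_of_unpack(same) is x                  # cancels
assert fold_pack_of_unpack(different) is different     # must stay
```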

benvanik commented 8 months ago

I'm still not clear why (besides not being done yet) this would be complex or padding would matter - we're always slicing out ?x32x128xf32 from 1x1x?x32x128xf32 - 4096 elements (statically shaped) feels like the kind of scope packing should be operating on. We could phrase this as packing the vector in tensor<1x1x?xvector<32x128xf32>> - the outer dims don't matter. Are people talking about general cases here or are we really trying to pack gigabyte tensors on outer dims?

benvanik commented 8 months ago

(that is, we should be able to propagate packs with padding on innermost dims across ops that have static innermost dims that are impacted by the pack pad)

qedawkins commented 8 months ago

(that is, we should be able to propagate packs with padding on innermost dims across ops that have static innermost dims that are impacted by the pack pad)

Not sure why it's not included in the above repro, but we are doing packing on that dynamic dimension. The pack comes from the torch.aten.bmm in the sequence below:

%64 = flow.tensor.slice %_global_state.global[%c0, %c0, %c0, %c0, %c0 for %c1, %c1, %_global_seq_step.global, %c32, %c128] : tensor<64x1x4095x32x128xf32> -> tensor<1x1x?x32x128xf32>{%_global_seq_step.global}
%65 = flow.tensor.reshape %64 : tensor<1x1x?x32x128xf32>{%_global_seq_step.global} -> tensor<1x?x32x128xf32>{%_global_seq_step.global_0}
%193 = torch_c.from_builtin_tensor %65 : tensor<1x?x32x128xf32> -> !torch.vtensor<[1,?,32,128],f32>
%257 = torch.aten.transpose.int %193, %int1, %int2 : !torch.vtensor<[1,?,32,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,?,128],f32>
%404 = torch.aten.add.Tensor %397, %403, %int1 : !torch.vtensor<[1,32,1,128],f32>, !torch.vtensor<[1,32,1,128],f32>, !torch.int -> !torch.vtensor<[1,32,1,128],f32>
%405 = torch.prim.ListConstruct %257, %404 : (!torch.vtensor<[1,32,?,128],f32>, !torch.vtensor<[1,32,1,128],f32>) -> !torch.list<vtensor>
%406 = torch.aten.cat %405, %int2 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,32,?,128],f32>
%409 = torch.aten.transpose.int %406, %int2, %int3 : !torch.vtensor<[1,32,?,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,128,?],f32>
%416 = torch.aten.expand %409, %415, %false : !torch.vtensor<[1,32,128,?],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,32,128,?],f32>
%418 = torch.aten.view %416, %417 : !torch.vtensor<[1,32,128,?],f32>, !torch.list<int> -> !torch.vtensor<[32,128,?],f32>
%419 = torch.aten.bmm %413, %418 : !torch.vtensor<[32,1,128],f32>, !torch.vtensor<[32,128,?],f32> -> !torch.vtensor<[32,1,?],f32>

That dynamic dimension is the n dimension of that batch matmul and is thus packed. It is outermost in the global so that the torch.aten.cat of the new sequence doesn't require a transient allocation for the full concatenated context (because the concat happens along the dynamic dimension).
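A small Python sketch of why that matters: packing rounds each packed dimension up to a multiple of its tile. The static 32 and 128 inner dims divide evenly by typical tile sizes, so they need no padding, while the dynamic sequence dimension used as the bmm N gets a padding amount that changes at every step. The tile sizes here are hypothetical; the real ones depend on the target and the data-tiling heuristics.

```python
def padded_extent(dim: int, tile: int) -> int:
    """Size of `dim` after packing with inner tile `tile` (rounded up)."""
    return -(-dim // tile) * tile


def padding(dim: int, tile: int) -> int:
    return padded_extent(dim, tile) - dim


N_TILE = 16  # hypothetical tile size for the bmm N (sequence) dimension

# Static inner dims of the cache: evenly divisible, so no padding.
assert padding(32, 8) == 0 and padding(128, 16) == 0

# Dynamic N: the amount of padding depends on the current context length.
for seq_len in (1, 15, 16, 17, 4095):
    print(seq_len, padding(seq_len, N_TILE))
# 1 15
# 15 1
# 16 0
# 17 15
# 4095 1
```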

benvanik commented 8 months ago

I give up until we have clearer repros/isolation/understand what's happening.