rsuderman opened this issue 1 week ago (status: Open).
Well, yes. The concats aren't contiguous?
If the concat cannot be made contiguous on ingestion, you can try --iree-opt-outer-dim-concat=true.
I thought they were, right? The concat is happening on the outermost dim. I think the issue is that when this gets decomposed into tensor.insert_slice ops, we lose the knowledge that we are inserting into the full length of the dynamic dim.
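A minimal Python sketch of the contiguity condition at stake, assuming a dense row-major layout. The function name `is_contiguous_slice` and the encoding are hypothetical, not IREE's API: a static size is an int, a dynamic size is the name of its SSA value (so identical names mean provably equal sizes), and None stands for a size the analysis could not tie to anything. A unit-stride slice is contiguous when, walking from the innermost dim outward, every dim spans the full destination extent up to one partial dim, and every dim outside that one has size 1.

```python
from typing import Optional, Sequence, Union

# Static size (int), dynamic size named by its SSA value (str), or a size
# the analysis could not resolve (None). Hypothetical encoding.
Dim = Optional[Union[int, str]]


def _spans_full_dim(offset: Dim, size: Dim, extent: Dim) -> bool:
    # Provably covers the whole dim: zero offset and a size identical to the
    # destination extent (same int, or the same SSA value name).
    return offset == 0 and size is not None and extent is not None and size == extent


def is_contiguous_slice(offsets: Sequence[Dim], sizes: Sequence[Dim],
                        strides: Sequence[Dim], dest_shape: Sequence[Dim]) -> bool:
    """Hypothetical check: is this slice one contiguous row-major region?"""
    if any(s != 1 for s in strides):
        return False
    # Walk from the innermost dim outward while the slice spans the full dim.
    k = len(dest_shape) - 1
    while k >= 0 and _spans_full_dim(offsets[k], sizes[k], dest_shape[k]):
        k -= 1
    # Dim k may be partial; every dim outside it must have size exactly 1.
    return all(sizes[j] == 1 for j in range(k))


dest = [8, "%d", 4, 64]  # tensor<8x?x4x64xf32>, with ? bound to SSA value %d
# Second concat operand, inserted at offset 4 along dim 0.
print(is_contiguous_slice([4, 0, 0, 0], [4, "%d", 4, 64], [1] * 4, dest))  # True
# Same slice once the link between the two ? sizes has been lost.
print(is_contiguous_slice([4, 0, 0, 0], [4, None, 4, 64], [1] * 4, dest))  # False
```

The second call models the situation described above: once the operand's ? and the result's ? are no longer known to be the same value, the insert can no longer be proven contiguous, even though it is.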
The choice to decompose tensor.concat was just because that's where I stopped plumbing through support for it. It should be easy to keep it around and convert it to flow.tensor.update ops if it's on the outermost dim.
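Keeping the concat around makes the check structural: all operands of a concat agree on every non-concatenated dim, so each operand is one contiguous slab of the result whenever every dim before the concat dim is statically 1 (trivially true for dim 0). Below is a hypothetical sketch of that condition and of the per-operand offsets an update-style lowering would use. `concat_update_offsets` is an illustrative name, and the sketch does not model the actual flow.tensor.update op or its syntax.

```python
from typing import List, Optional, Sequence, Union

Dim = Union[int, str]  # static size, or the SSA name of a dynamic size


def concat_update_offsets(operand_shapes: Sequence[Sequence[Dim]],
                          dim: int) -> Optional[List[List[Dim]]]:
    """Hypothetical: offsets at which each concat operand would be written
    as one contiguous update, or None if the operands would be strided."""
    rank = len(operand_shapes[0])
    # Non-concat dims match across operands, so the dims after `dim` are always
    # full. Only the dims before `dim` can break contiguity: they must be 1.
    if any(operand_shapes[0][j] != 1 for j in range(dim)):
        return None
    offsets: List[List[Dim]] = []
    running: Dim = 0
    for shape in operand_shapes:
        offsets.append([running if j == dim else 0 for j in range(rank)])
        size = shape[dim]
        if isinstance(running, int) and isinstance(size, int):
            running += size
        elif running == 0:
            running = size
        else:
            # Dynamic sizes: keep the offset as a symbolic sum.
            running = f"{running} + {size}"
    return offsets


arg = [4, "%d", 4, 64]  # tensor<4x?x4x64xf32>
print(concat_update_offsets([arg, arg], 0))  # [[0, 0, 0, 0], [4, 0, 0, 0]]
print(concat_update_offsets([arg, arg], 1))  # None, because dim 0 has extent 4
```

Because the decision only inspects the static extents of the leading dims, it needs no value-bounds reasoning about the dynamic ?.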
func.func @main(%arg0: tensor<4x?x4x64xf32>, %arg1: tensor<4x?x4x64xf32>) -> tensor<8x?x4x64xf32> {
  %1 = tensor.concat dim(0) %arg0, %arg1 : (tensor<4x?x4x64xf32>, tensor<4x?x4x64xf32>) -> tensor<8x?x4x64xf32>
  return %1 : tensor<8x?x4x64xf32>
}
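A worked example of the layout arithmetic for @main, assuming f32 elements (4 bytes), a dense row-major result, and the dynamic dim bound to 3 at runtime (an arbitrary value chosen for illustration). Each operand occupies one contiguous byte range of the result, so the concat amounts to two plain copies rather than anything strided.

```python
from typing import List, Sequence, Tuple

F32_BYTES = 4


def outer_concat_byte_ranges(outer_sizes: Sequence[int], inner_shape: Sequence[int],
                             elem_bytes: int = F32_BYTES) -> List[Tuple[int, int]]:
    """[start, end) byte range each operand occupies in a dim-0 concat result,
    assuming a dense row-major layout with all dynamic dims already known."""
    # Bytes in one step along dim 0: element size times the inner extents.
    row_bytes = elem_bytes
    for extent in inner_shape:
        row_bytes *= extent
    ranges: List[Tuple[int, int]] = []
    start = 0
    for n in outer_sizes:
        end = start + n * row_bytes
        ranges.append((start, end))
        start = end  # the next operand starts where this one ends: no gaps
    return ranges


# @main: two tensor<4x?x4x64xf32> operands, with ? = 3 chosen for illustration.
print(outer_concat_byte_ranges([4, 4], [3, 4, 64]))  # [(0, 12288), (12288, 24576)]
```

With ? = 3, each operand is 4 × 3 × 4 × 64 × 4 = 12288 bytes, so the two operands fill [0, 12288) and [12288, 24576) of the 24576-byte result.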
Wait, isn't this sample degenerate? We always have to copy because the only thing here is a concat.
I think we need some surrounding dispatches to see whether this actually always happens.
@qedawkins I think the problem is that it gets lowered to a flow.dispatch + flow.tensor.load + flow.tensor.store instead of a flow.tensor.update, because isOffsetSizeAndStrideMappableToFlow returns false. I tried to change https://github.com/iree-org/iree/blob/915b06bb8c3a5471f1a28e743dbe4a403ebc591b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp#L111-L120, but ValueBoundsOpInterface doesn't work with HAL buffer dim ops.
Sorry, I missed that it is contiguous. @IanWood1, let's look at this more closely today in our 1:1.
Again, this might be a degenerate case. isOffsetSizeAndStrideMappableToFlow does a somewhat hacky analysis of whether an SSA value is derived from a value computed on device. Because the inputs to the concat are function arguments, the tensor.dim %arg0 doesn't get replaced with any real SSA value, so that helper gives up. That might be the bug here, but this particular test case will always produce memcpys even if we convert to flow.tensor.update.
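A toy model of the kind of conservative def-chain walk described above; it is not the actual isOffsetSizeAndStrideMappableToFlow logic. The IR encoding, the `depends_on_device` name, the rule that tensor.extract counts as device data, and the `resolved` table (standing in for a tensor.dim that has been replaced with a real SSA value) are all assumptions. A tensor.dim that was never resolved, such as one on a function argument, makes the walk give up and return None, which a caller would have to treat as not mappable.

```python
from typing import Dict, Optional, Tuple

# Toy IR: SSA name -> (op name, operand names). Not IREE's data structures.
Defs = Dict[str, Tuple[str, Tuple[str, ...]]]


def depends_on_device(value: str, defs: Defs, resolved: Dict[str, str]) -> Optional[bool]:
    """True: derived from device data. False: host-computable. None: gave up.

    `resolved` maps a tensor.dim result to the SSA value it was replaced
    with (hypothetical stand-in for that replacement step).
    """
    if value not in defs:
        return None  # nothing to trace (e.g. a block argument)
    op, operands = defs[value]
    if op == "arith.constant":
        return False
    if op == "tensor.extract":
        return True  # assumed: reading tensor contents counts as device data
    if op == "tensor.dim":
        if value not in resolved:
            return None  # e.g. tensor.dim %arg0 on a function argument
        return depends_on_device(resolved[value], defs, resolved)
    # Any other op: combine the answers for its operands.
    answers = [depends_on_device(v, defs, resolved) for v in operands]
    if any(a is True for a in answers):
        return True
    if any(a is None for a in answers):
        return None
    return False


DEFS: Defs = {
    "%c4": ("arith.constant", ()),
    "%d0": ("tensor.dim", ("%arg0",)),   # dim of a function argument
    "%d1": ("tensor.dim", ("%r",)),      # dim of some produced tensor
    "%m": ("arith.muli", ("%c4", "%c4")),
    "%x": ("tensor.extract", ("%t",)),
    "%sum": ("arith.addi", ("%c4", "%d1")),
}
RESOLVED = {"%d1": "%m"}  # assume %d1 was replaced by the host value %m

print(depends_on_device("%d0", DEFS, RESOLVED))   # None: the walk gives up
print(depends_on_device("%sum", DEFS, RESOLVED))  # False
print(depends_on_device("%x", DEFS, RESOLVED))    # True
```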
I haven't looked at this in a while, so I'll need some context. Even for this degenerate case, we should be able to avoid generating slow memcpys, even if it means keeping the concat around longer.
OK, for now I am taking this over from Ian. I'll start by trying to push concat further down the compilation flow, because at the insert_slice level you don't have the information that this is contiguous.
The @main example above always generates a slow memory copy to perform the concatenation.