ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

MLIR: Fuse slice operators with an elementwise #2767

Closed: pfultz2 closed this 3 months ago

pfultz2 commented 8 months ago

In unet there is this pattern:

@249 = gpu::code_object[code_object=7216,symbol_name=mlir_dot_add,global=1310720,local=256,](@248,@244,@245,@247) -> half_type, {2, 4096, 5120}, {20971520, 5120, 1}, target_id=0
@250 = slice[axes={2},starts={0},ends={2560}](@249) -> half_type, {2, 4096, 2560}, {20971520, 5120, 1}, target_id=0
@251 = load[offset=229826304,end=271769344](@1) -> half_type, {2, 4096, 2560}, {10485760, 2560, 1}, target_id=0
@252 = slice[axes={2},starts={2560},ends={5120}](@249) -> half_type, {2, 4096, 2560}, {20971520, 5120, 1}, target_id=0
@253 = gpu::code_object[code_object=9704,symbol_name=mul_neg_exp_add_div_mul_kernel,global=5242880,local=1024,](@252,@250,@251) -> half_type, {2, 4096, 2560}, {10485760, 2560, 1}, target_id=0

Because of the slices we don't fuse mul_neg_exp_add_div_mul, but we could fold this into the same kernel.
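For concreteness, a small numpy sketch of the computation (output shapes come from the IR above; K=320 and the exact elementwise expression are guesses, since the kernel name only hints at sigmoid-style gating):

import numpy as np

# Shapes taken from the IR above; K=320 is a made-up inner dimension.
A = np.random.rand(2, 4096, 320).astype(np.float16)
B = np.random.rand(2, 320, 5120).astype(np.float16)
bias = np.random.rand(5120).astype(np.float16)

C = A @ B + bias          # @249 mlir_dot_add -> {2, 4096, 5120}
Cl = C[:, :, :2560]       # @250 slice axes={2}, [0, 2560)
Cr = C[:, :, 2560:]       # @252 slice axes={2}, [2560, 5120)

# @253 mul_neg_exp_add_div_mul: the exact expression is a guess; the op
# names (neg, exp, add, div) suggest a sigmoid, with multiplies around it.
out = Cr * (1 / (1 + np.exp(-Cl.astype(np.float32)))).astype(np.float16)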

@krzysz00 Is this supported on the MLIR side? I am assuming it is, so we only need to fuse this on the migraphx side.

krzysz00 commented 8 months ago

@pfultz2 To my understanding, this has a solid chance of not working.

Looking at the code, to abstract it out some, what you've got is

C <- dot_add(A, B) : M x N x f16
Cl = slice(C, axis=1, offset=0, length=N') : M x N' x f16
Cr = slice(C, axis=1, offset=N', length=N'): M x N' x f16
Cfused <- mul_...(Cl, Cr, ...) : M x N' x f16

and we'd like to fuse away C

The problem I see with this fusion is that the columns belonging to Cl and the ones belonging to Cr can be located on, for instance, different workgroups, so it'd be, at best, extremely hard to get the communication working to solve the problem as written.
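To see the colocation problem concretely, here's a tiny sketch under an assumed tiling (128-column output tiles handed to workgroups left to right; the tile size is made up):

# Assumed tiling: each workgroup computes a 128-column tile of C's output,
# assigned left to right along N. The tile size is illustrative only.
N_half, tile = 2560, 128

def owning_workgroup(col):
    # Workgroup that would compute this output column of C under the
    # assumed tiling.
    return col // tile

j = 7  # an arbitrary column index within each slice
print(owning_workgroup(j))           # Cl[:, j] comes from workgroup 0
print(owning_workgroup(N_half + j))  # Cr[:, j] comes from workgroup 20
# The fused elementwise kernel needs Cl[:, j] and Cr[:, j] in the same
# lane, but they were produced by different workgroups.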

Now, we could imagine the following rewrite

Bl = slice(B, axis=1, offset=0, length=N') : K x N' x f16
Br = slice(B, axis=1, offset=N', length=N') : K x N' x f16
Cl <- dot_add(A, Bl, ...) : M x N' x f16
Cr <- dot_add(A, Br, ...) : M x N' x f16
Cfused <- mul_...(Cl, Cr, ...) : M x N' x f16

In principle, this is a workable fusion: each lane will have the outputs of both gemms, and, since the gemms will be the same size (and, to preserve our sanity, use the same MFMA/WMMAs), they'll have computed the same segment of the final output. Those segments could then be combined into the final output within registers during that final fusion.
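A quick numpy check of the identity this rewrite relies on (small made-up shapes): slicing the gemm output along N is the same as slicing B along N before the gemm.

import numpy as np

M, K, N = 64, 32, 16   # small made-up shapes
Np = N // 2            # N'
A = np.random.rand(M, K)
B = np.random.rand(K, N)

C = A @ B
Bl, Br = B[:, :Np], B[:, Np:]

# slice-after-gemm == gemm-after-slice, column-wise
assert np.allclose(C[:, :Np], A @ Bl)
assert np.allclose(C[:, Np:], A @ Br)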

I say "in principle" because a lot of our API surface is operating under the assumption that there's only one (implicit) gemm per offloaded function, so trying to do this fusion would be a great way to smack right into a lot of asserts or strange failures.

So the question becomes: how much do you need this, and would you be OK with us getting back to you on having something like that cleanly supported months from now?

pfultz2 commented 8 months ago

> Now, we could imagine the following rewrite

Is this a rewrite we would do, or would you do it on the MLIR side? I guess we could start on the migraphx side at first.

I say "in principle" because a lot of our API surface is operating under the assumption that there's only one (implicit) gemm per offloaded function, so trying to do this fusion would be a great way to smack right into a lot of asserts or strange failures.

I guess we could try it and see what blows up.

> how much do you need this, and would you be OK with us getting back to you on having something like that cleanly supported months from now?

I am not sure; I assume that if you make this a priority, other things will need to be put on hold. I was hoping this could be done with very little change to MLIR.

With the rewrite you are suggesting, it could be possible to do this with CK, as it's just a group gemm.

krzysz00 commented 8 months ago

> With the rewrite you are suggesting, it could be possible to do this with CK, as it's just a group gemm.

Or with us. But then you lose the fusion possibilities, because the data for G=0 and G=1 aren't colocated.

But yeah, no, this won't be a simple change on our part. At best, it's wading through several layers of API weirdness to relax some checks / change the behavior when you try to apply a perf config to multiple GEMMs. At worst, you get late-compile error messages of the form "we don't have enough LDS for that", because each GEMM checks how much LDS it needs in isolation.

So, no, this sort of "combine the results of two independent gemms together at the lane level" fusion isn't something we support, and it's fair to say we've made some substantially simplifying assumptions (e.g. in the regularizer that makes fusion less hairy overall) around the fact that this isn't part of the use case.

(As to who does the rewrite, it'd make sense for y'all to do it, since you're the ones who'd be recognizing the "elementwise(slice(X), slice(X), ...) where X = dot[_elementwise]" pattern in the first place.)
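To make the shape of that rewrite concrete, a hedged, self-contained Python sketch over a toy IR (the Node class, its attribute names, and rewrite_slice_of_dot are all illustrative; the real pass would be written against MIGraphX's C++ matcher framework):

from dataclasses import dataclass, field

# Toy IR node, purely illustrative.
@dataclass
class Node:
    op: str
    inputs: list = field(default_factory=list)
    attrs: dict = field(default_factory=dict)

def rewrite_slice_of_dot(elementwise: Node) -> Node | None:
    # Match elementwise(slice(X), slice(X), ...) where X = dot, and move
    # the slices onto the B operand of the gemm (only valid when the
    # sliced axis is the gemm's N dimension).
    slices = [i for i in elementwise.inputs if i.op == "slice"]
    if len(slices) != 2 or slices[0].inputs[0] is not slices[1].inputs[0]:
        return None
    dot = slices[0].inputs[0]
    if dot.op != "dot":
        return None
    A, B = dot.inputs
    new_gemms = []
    for s in slices:
        Bs = Node("slice", [B], dict(s.attrs))  # slice B along N instead of C
        new_gemms.append(Node("dot", [A, Bs]))
    rest = [i for i in elementwise.inputs if i.op != "slice"]
    return Node(elementwise.op, new_gemms + rest)

# Build the pattern from the thread and apply the rewrite.
A, B = Node("param"), Node("param")
C = Node("dot", [A, B])
Cl = Node("slice", [C], {"axes": [2], "starts": [0], "ends": [2560]})
Cr = Node("slice", [C], {"axes": [2], "starts": [2560], "ends": [5120]})
fused = Node("mul_neg_exp_add_div_mul", [Cr, Cl])
print(rewrite_slice_of_dot(fused).inputs[0].op)  # "dot": the slices now feed the gemms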

causten commented 3 months ago

Will not be worked on.