ROCm / triton

Development repository for the Triton language and compiler
MIT License

Use full-vectorized load instructions for load vectorization #445

Closed: htyu closed this 7 months ago

htyu commented 8 months ago

I'm not sure why the existing load-vectorization code emits a sequence of short vectorized loads instead of a single full 128-bit load. Splitting the access into several shorter loads makes us depend on the LLVM backend (especially the load/store vectorizer) to recombine them into a fully vectorized load. This seems fragile: in some cases I saw the VectorCombine and JumpThreading passes break the pattern, which resulted in suboptimal vectorization.
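To make the trade-off concrete, here is a toy Python model (not Triton's actual codegen) of the two ways to fetch four contiguous 32-bit floats: a single 16-byte (128-bit) read versus four separate 4-byte reads. Both return the same values. The difference is that the segmented form issues four accesses and relies on a later pass (in the real compiler, LLVM's load/store vectorizer) to merge them back into one. The buffer and function names are hypothetical.

```python
import struct

# Hypothetical global buffer holding eight little-endian float32 values 0.0..7.0.
MEM = struct.pack("<8f", *(float(i) for i in range(8)))


def wide_load(mem: bytes, byte_offset: int):
    """One 128-bit access: read 16 bytes at once, then split into four floats."""
    values = list(struct.unpack_from("<4f", mem, byte_offset))
    return values, 1  # (values, number of memory accesses issued)


def segmented_load(mem: bytes, byte_offset: int):
    """Four 32-bit accesses that a later pass would have to merge back together."""
    values = [struct.unpack_from("<f", mem, byte_offset + 4 * i)[0] for i in range(4)]
    return values, 4


print(wide_load(MEM, 16))       # ([4.0, 5.0, 6.0, 7.0], 1)
print(segmented_load(MEM, 16))  # ([4.0, 5.0, 6.0, 7.0], 4)
```

The values agree either way; only the number of accesses differs, and that is exactly the part the backend has to recover when the segmented form is emitted.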

zhanglx13 commented 8 months ago

@htyu Thanks for looking into this issue. However, I don't think this is the right way to solve the problem.

If there is an issue with global load vectorization in your custom kernel, we are happy to help.

htyu commented 8 months ago

Thanks for the comments.

> We are aware of the fact that in some cases the LLVM backend cannot (or has a good reason not to) vectorize global loads. For now we need to massage the address computation carefully

What LLVM instructions do you expect to generate with careful address computation: a fully vectorized load or a sequence of shorter loads? In my experience, a single long load survives the LLVM backend more reliably than a sequence of shorter ones.

Inline assembly should work, but I'm not sure whether it has side effects on other LLVM optimizations. What problem do you see with the long load?

zhanglx13 commented 8 months ago

> What LLVM instructions do you expect to generate with careful address computation: a fully vectorized load or a sequence of shorter loads? In my experience, a single long load survives the LLVM backend more reliably than a sequence of shorter ones.

I expect a sequence of shorter loads. We haven't tried a fully vectorized global load at the LLVM level because we are trying to reuse as much code from the NV path as possible. I agree that inline assembly can cause issues related to memory synchronization, so I'm happy to see that we can avoid it here.

There are a lot of failing tests. Can you make them pass first?

htyu commented 8 months ago

> There are a lot of failing tests. Can you make them pass first?

Sure. I'll work on fixing the test failures and make sure the change doesn't affect the NV path.

P.S. The problem I was seeing is that the VectorCombine pass converted the original four i32 loads into four 2xi16 loads. The JumpThreading pass then threaded those four 2xi16 loads, removing the redundant mask checks for the first three. During threading, the first three loads were further decomposed into six i16 loads, which were vectorized later. The fourth 2xi16 load was excluded from that vectorization because it was already a vector. Emitting a single 8xi16 load in the first place appeared to be immune to all of these issues.

bertmaher commented 8 months ago

We root-caused the poor performance of RMS norm (#422) to this issue. It seems that the combination of the for-loop and the control-flow-based load masking confuses the load/store vectorizer, so we end up with a dword + dword3 load instead of a single dword4 load.
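A rough Python model of how a dword + dword3 split can arise: a vectorizer can only merge loads whose addresses it can prove are adjacent. In this toy (it is not LLVM's algorithm), each load is described by a symbolic base, a constant byte offset, and a size; adjacent loads from the same base are merged greedily into chunks of at most 16 bytes. If the first load's address reaches the vectorizer as an opaque value `q` (for example, produced in a different block around the masking branch), it cannot be merged with the other three even though `q == p` at run time. The names `p` and `q` are hypothetical, and the dword labels simply follow the comment above (dword = 4 bytes).

```python
from typing import List, Tuple

Load = Tuple[str, int, int]  # (symbolic base, constant byte offset, size in bytes)


def merge_loads(loads: List[Load], max_bytes: int = 16) -> List[Load]:
    """Greedily merge loads that provably touch adjacent bytes of the same base."""
    merged: List[Load] = []
    for base, offset, size in loads:
        if merged:
            pbase, poff, psize = merged[-1]
            # Adjacent only if the symbolic base matches and the offsets line up exactly.
            if pbase == base and poff + psize == offset and psize + size <= max_bytes:
                merged[-1] = (pbase, poff, psize + size)
                continue
        merged.append((base, offset, size))
    return merged


def describe(merged: List[Load]) -> str:
    """Label each merged chunk by its number of 4-byte dwords."""
    labels = []
    for _, _, size in merged:
        n = size // 4
        labels.append("dword" if n == 1 else f"dword{n}")
    return " + ".join(labels)


# All four addresses expressed against the same base: one 128-bit load.
good = [("p", 0, 4), ("p", 4, 4), ("p", 8, 4), ("p", 12, 4)]
# The first address is an opaque value q (equal to p at run time, but not provably so).
bad = [("q", 0, 4), ("p", 4, 4), ("p", 8, 4), ("p", 12, 4)]

print(describe(merge_loads(good)))  # dword4
print(describe(merge_loads(bad)))   # dword + dword3
```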

htyu commented 8 months ago

This is exactly the issue I'm fixing here.

zhanglx13 commented 8 months ago

CC +@scxiao

codego7250 commented 7 months ago

> It seems like the combination of the for-loop and the control-flow based load masking confuses the load/store vectorizer, and we end up with a dword+dword3 load instead of a dword4 load.

This is a known issue. It can be worked around in hand-written code. For compiler-generated code, we may later consider a new primitive-based solution ("rely on the user to ensure correctness"). So far, we have stayed compatible with the Triton NVIDIA GPU API, and we give correctness higher priority than performance.

htyu commented 7 months ago

> So far, we have stayed compatible with the Triton NVIDIA GPU API, and we give correctness higher priority than performance.

I think in this case we should still be able to fix the compiler's performance without affecting correctness.

One path I'm taking is to fix the LLVM GPU load/store vectorizer, where I saw that scalar evolution (SCEV) could not infer that two addresses were consecutive. FYI: https://discourse.llvm.org/t/how-to-compare-scevs/76174
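As a loose illustration of what "inferring that two addresses are consecutive" involves, the toy below represents each address as an affine expression (a constant plus integer multiples of symbols), subtracts the two symbolically, and treats them as consecutive only if the difference reduces to exactly the element size. This is only loosely modeled on comparing SCEVs by taking their difference; the representation and every name here (`linear`, `difference`, `provably_consecutive`, the symbols `base`, `i`, `j`) are hypothetical. The failing case uses a symbol `j` that equals `i` at run time but is opaque to the analysis.

```python
from collections import Counter
from typing import Dict


def linear(const: int = 0, **terms: int) -> Counter:
    """A toy affine address: a constant plus integer multiples of named symbols."""
    expr = Counter(terms)
    expr["<const>"] = const
    return expr


def difference(a: Counter, b: Counter) -> Dict[str, int]:
    """Compute b - a symbolically, dropping terms whose coefficients cancel."""
    diff = Counter(b)
    diff.subtract(a)
    return {k: v for k, v in diff.items() if v != 0}


def provably_consecutive(a: Counter, b: Counter, size: int) -> bool:
    """True only if b - a reduces to exactly the constant `size`."""
    d = difference(a, b)
    return set(d) <= {"<const>"} and d.get("<const>", 0) == size


a = linear(0, base=1, i=4)  # base + 4*i
b = linear(4, base=1, i=4)  # base + 4*i + 4
c = linear(4, base=1, j=4)  # base + 4*j + 4, where j == i at run time but is opaque

print(provably_consecutive(a, b, 4))  # True
print(provably_consecutive(a, c, 4))  # False
```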

Since redundant load masks (which we rely on LLVM to eliminate) were causing issues, I'm also taking another route: avoiding generating the redundant checks in the first place. Please see if the new version looks reasonable; I have yet to fix the test failures. The resulting codegen:

    %129:4 = scf.if %128 -> (i32, i32, i32, i32) {
      %251 = llvm.addrspacecast %124 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %252 = llvm.load %251 : !llvm.ptr<i32>
      %253 = llvm.addrspacecast %125 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %254 = llvm.load %253 : !llvm.ptr<i32>
      %255 = llvm.addrspacecast %126 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %256 = llvm.load %255 : !llvm.ptr<i32>
      %257 = llvm.addrspacecast %127 : !llvm.ptr<f32, 1> to !llvm.ptr<i32>
      %258 = llvm.load %257 : !llvm.ptr<i32>
      scf.yield %252, %254, %256, %258 : i32, i32, i32, i32
    } else {
      %251 = llvm.mlir.constant(0 : i32) : i32
      %252 = llvm.mlir.constant(0 : i32) : i32
      %253 = llvm.mlir.constant(0 : i32) : i32
      %254 = llvm.mlir.constant(0 : i32) : i32
      scf.yield %251, %252, %253, %254 : i32, i32, i32, i32
    }

    %130 = llvm.bitcast %129#0 : i32 to vector<1xf32>
    %131 = llvm.mlir.constant(0 : index) : i32
    %132 = llvm.extractelement %130[%131 : i32] : vector<1xf32>
    %133 = llvm.bitcast %129#1 : i32 to vector<1xf32>
    %134 = llvm.mlir.constant(0 : index) : i32
    %135 = llvm.extractelement %133[%134 : i32] : vector<1xf32>
    %136 = llvm.bitcast %129#2 : i32 to vector<1xf32>
    %137 = llvm.mlir.constant(0 : index) : i32
    %138 = llvm.extractelement %136[%137 : i32] : vector<1xf32>
    %139 = llvm.bitcast %129#3 : i32 to vector<1xf32>
    %140 = llvm.mlir.constant(0 : index) : i32
    %141 = llvm.extractelement %139[%140 : i32] : vector<1xf32>
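
In the codegen above, a single `scf.if %128` guards all four i32 loads, and the `else` branch yields zeros, instead of checking a mask once per load. The Python sketch below contrasts the two forms. It assumes the compiler only emits the hoisted form when the mask is known to be uniform across the lanes covered by one load; how that is established is not shown here, and the function names are hypothetical.

```python
from typing import List


def per_element_masked_load(mem: List[int], base: int, mask: List[bool]) -> List[int]:
    """One predicate check per element, like four separately masked scalar loads."""
    return [mem[base + i] if mask[i] else 0 for i in range(len(mask))]


def hoisted_masked_load(mem: List[int], base: int, mask: List[bool]) -> List[int]:
    """One predicate check for the whole vector, mirroring the single scf.if.

    Only valid when every lane shares the same predicate value.
    """
    assert len(set(mask)) == 1, "mask must be uniform across the vector"
    if mask[0]:
        return mem[base:base + len(mask)]  # all four words loaded together
    return [0] * len(mask)                 # the else branch yields zeros


MEM = list(range(100, 116))
print(hoisted_masked_load(MEM, 4, [True] * 4))       # [104, 105, 106, 107]
print(per_element_masked_load(MEM, 4, [True] * 4))   # [104, 105, 106, 107]
print(hoisted_masked_load(MEM, 4, [False] * 4))      # [0, 0, 0, 0]
```

With a uniform mask the two forms agree, but the hoisted one leaves no redundant per-load checks for LLVM to clean up.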
codego7250 commented 7 months ago

This looks good. There may be some corner cases in the unit tests related to the predicate, etc. Let's make sure it works for all of them.

htyu commented 7 months ago

> This looks good. There may be some corner cases in the unit tests related to the predicate, etc. Let's make sure it works for all of them.

Thanks. On second thought, though, I'm inclined to generate a fully vectorized load whenever possible. That should make the result less sensitive to the uncertainty of LLVM's optimization passes. It also reduces the size of the LLVM IR, which should improve compile time. Please check my latest version and see if it looks good.
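Generating a fully vectorized load "whenever possible" means picking a vector width. A minimal sketch under stated assumptions: the width is taken as the smallest of the known contiguity (in elements), the known alignment (in elements), and what fits in 128 bits, then rounded down to a power of two. This follows the general idea behind Triton's axis analysis, but `vector_width` and `load_type` are hypothetical helpers, not Triton's API, and the sketch assumes `elem_bits <= 128`.

```python
def vector_width(contiguity: int, alignment_elems: int, elem_bits: int,
                 max_bits: int = 128) -> int:
    """Elements per load, bounded by contiguity, alignment, and the 128-bit limit."""
    width = min(contiguity, alignment_elems, max_bits // elem_bits)
    # Round down to a power of two so the load maps onto a regular vector type.
    p = 1
    while p * 2 <= width:
        p *= 2
    return p


def load_type(contiguity: int, alignment_elems: int, elem_bits: int) -> str:
    """Spell the chosen load as an LLVM-style integer vector type string."""
    n = vector_width(contiguity, alignment_elems, elem_bits)
    return f"<{n} x i{elem_bits}>"


print(load_type(8, 8, 16))  # <8 x i16>  (a full 128-bit load of 16-bit elements)
print(load_type(4, 4, 32))  # <4 x i32>
print(load_type(3, 8, 32))  # <2 x i32>  (contiguity 3 rounds down to 2)
```

The first case corresponds to the single 8xi16 load mentioned earlier in the thread, which replaces the four i32 loads that the backend previously had to recombine.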

zhanglx13 commented 7 months ago

@htyu Thanks for fixing this issue. I tested on MI250 and MI300 and confirmed that we no longer need this trick: https://github.com/ROCmSoftwarePlatform/triton/blob/e7033218d6a0f0f1129aa3adc1bfbbe57c84fd20/python/tutorials/03-matrix-multiplication.py#L253-L257

cc+ @scxiao

htyu commented 7 months ago

Thanks for giving it a try!

BTW, how should I land this patch?

zhanglx13 commented 7 months ago

@htyu I'll land it. One more thing: do you think this method also works for the NV path? If so, it would be better to merge the two paths.

htyu commented 7 months ago

I'll need to take a deeper look. NV loads come with cache flags, and I'm not sure how to express them in the LLVM dialect. But yes, in general I'm not in favor of using asm volatile. It'd be great to get rid of them.

zhanglx13 commented 7 months ago

@htyu Sounds good. Keep us posted!

zhanglx13 commented 5 months ago

@htyu Since we are moving our dev work upstream and closing the perf gap between this fork and upstream, could you please upstream this PR?

htyu commented 5 months ago

Sure, will do.

Do you need me to upstream other PRs I made in this repo?

zhanglx13 commented 5 months ago

Yes, that would be great. Thank you very much! I think we can hold off on the dot-slicing-related PRs for now, since they are still experimental.

htyu commented 5 months ago

Upstreaming PR: https://github.com/openai/triton/pull/3609