csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Rewrite `reducePredicateRegisterUsage` #2533

Closed. zasdfgbnm closed this 1 year ago.

zasdfgbnm commented 1 year ago

The former approach does not make sense because it does a lot of reordering even when there is no register saving. This reordering is annoying because it makes the generated code very hard to read. I am rewriting this pass so that it only reorders terms when doing so actually saves registers.

zasdfgbnm commented 1 year ago

Marking this as ready, but I would like to wait for https://github.com/csarofeen/pytorch/pull/2500 because I don't want this to conflict with the new assertCUDAKernels in the loop rotation tests.

naoyam commented 1 year ago

when there is a register saving.

I haven't looked into the PR yet, but how do you know if there's register saving?

zasdfgbnm commented 1 year ago

when there is a register saving.

I haven't looked into the PR yet, but how do you know if there's register saving?

I changed this pass to only consider register savings in unrolled loops. For example, in threadIdx.x + 3 < T0.size[0] there is no unrolled loop, so changing it to threadIdx.x - T0.size[0] < -3 does not save anything; the pass therefore no longer moves terms across the <. The pass works by finding all terms that depend on an unrolled loop index, computing their register type, and comparing it with the register type of the remaining terms. If there is a saving (that is, the remaining terms need a general purpose register while the unrolled terms need only a uniform register or an immediate, or the remaining terms need a uniform register while the unrolled terms need only an immediate), the pass moves the terms.
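The decision rule described above can be sketched roughly as follows. This is a hypothetical, stand-alone model, not the code in the PR: the names RegType, Term, combine and shouldMoveTerms are made up for illustration, and it assumes the register type of a sum is simply the most expensive type among its terms (imm < uniform < gp).

```cpp
#include <cassert>
#include <vector>

// Hypothetical register classes, ordered from cheapest to most expensive.
enum class RegType { Imm = 0, Uniform = 1, GP = 2 };

// Hypothetical term of a predicate: the register class it needs and whether
// it depends on an unrolled loop index.
struct Term {
  RegType type;
  bool dependsOnUnrolledIndex;
};

// Assumption: a sum of terms needs the most expensive class among them.
RegType combine(RegType a, RegType b) {
  return static_cast<int>(a) >= static_cast<int>(b) ? a : b;
}

// Returns true if gathering the unrolled terms on one side of the comparison
// and the remaining terms on the other side saves registers, i.e. the
// remaining terms need a strictly more expensive class than the unrolled ones.
bool shouldMoveTerms(const std::vector<Term>& terms) {
  bool hasUnrolled = false;
  RegType unrolled = RegType::Imm;
  RegType remaining = RegType::Imm;
  for (const Term& t : terms) {
    if (t.dependsOnUnrolledIndex) {
      hasUnrolled = true;
      unrolled = combine(unrolled, t.type);
    } else {
      remaining = combine(remaining, t.type);
    }
  }
  // No unrolled loop index (e.g. threadIdx.x + 3 < T0.size[0]): nothing to save.
  if (!hasUnrolled) {
    return false;
  }
  // gp vs. uniform/imm, or uniform vs. imm, counts as a saving.
  return static_cast<int>(remaining) > static_cast<int>(unrolled);
}
```

For the predicate in the example that follows, the unrolled terms (i * 32 and blockIdx.y * i) combine to uniform and the remaining terms (threadIdx.x / 128 and 256) combine to gp, so this sketch decides to move terms.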

For example, if I have

#pragma unroll
for i = 0..8:
  threadIdx.x / 128 + i * 32 == 256 + blockIdx.y * i

Then I have register type:

gp,no_unroll + imm,unroll == imm,no_unroll + uniform,unroll

So I will need 8 general purpose registers for the left side and 8 uniform registers for the right side.
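The register-type annotations and counts above can be reproduced with a small hypothetical model (not nvFuser's API). It assumes threadIdx.x needs a general purpose register, blockIdx.y a uniform register, literals and the unrolled index are immediates, any binary operation takes the more expensive class of its operands, and an immediate side needs no register at all.

```cpp
#include <cassert>

// Hypothetical register classes, ordered from cheapest to most expensive.
enum class RegType { Imm = 0, Uniform = 1, GP = 2 };

// Hypothetical annotation of a subexpression: the register class it needs
// and whether it depends on the unrolled loop index.
struct Annot {
  RegType type;
  bool unrolled;
};

// Assumed classification of the leaves in the example.
constexpr Annot kThreadIdx{RegType::GP, false};      // differs per thread
constexpr Annot kBlockIdx{RegType::Uniform, false};  // same for all threads of a block
constexpr Annot kLiteral{RegType::Imm, false};       // compile-time constant
constexpr Annot kUnrolledIndex{RegType::Imm, true};  // a constant in each unrolled copy

// Assumption: any binary op (+, -, *, /) needs the more expensive class of its
// operands and depends on the unrolled index if either operand does.
constexpr Annot combine(Annot a, Annot b) {
  return Annot{static_cast<int>(a.type) >= static_cast<int>(b.type) ? a.type : b.type,
               a.unrolled || b.unrolled};
}

// Registers needed for one side of the comparison over the whole unrolled
// loop: 0 for an immediate, 1 if the side is loop invariant, otherwise one
// per unrolled iteration.
constexpr int registersNeeded(Annot side, int unrollFactor) {
  return side.type == RegType::Imm ? 0 : (side.unrolled ? unrollFactor : 1);
}
```

With an unroll factor of 8, threadIdx.x / 128 + i * 32 evaluates to {GP, unrolled} and needs 8 general purpose registers, while 256 + blockIdx.y * i evaluates to {Uniform, unrolled} and needs 8 uniform registers, matching the count above.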

If I rewrite it as threadIdx.x / 128 - 256 == blockIdx.y * i - i * 32, the register types become:

gp,no_unroll - imm,no_unroll == uniform,unroll - imm,unroll

Then I will need 1 general purpose register for the left side and 8 uniform registers for the right side, which saves 7 general purpose registers.
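The rewrite is only useful if it leaves the predicate unchanged. A hypothetical brute-force check of that for this example, in stand-alone C++: plain int parameters stand in for threadIdx.x, blockIdx.y and the unrolled index i, and the sampled ranges are arbitrary but small enough that nothing overflows.

```cpp
#include <cassert>

// Original form of the predicate for one unrolled iteration i.
bool originalPredicate(int tidx, int bidy, int i) {
  return tidx / 128 + i * 32 == 256 + bidy * i;
}

// Rearranged form: loop-invariant terms on the left, terms that depend on
// the unrolled index i on the right.
bool rearrangedPredicate(int tidx, int bidy, int i) {
  return tidx / 128 - 256 == bidy * i - i * 32;
}

// Exhaustively compares both forms for tidx in [0, maxTidx),
// bidy in [0, maxBidy) and i in [0, unrollFactor).
// Returns true if they agree everywhere in that range.
bool formsAgree(int maxTidx, int maxBidy, int unrollFactor) {
  for (int tidx = 0; tidx < maxTidx; ++tidx) {
    for (int bidy = 0; bidy < maxBidy; ++bidy) {
      for (int i = 0; i < unrollFactor; ++i) {
        if (originalPredicate(tidx, bidy, i) != rearrangedPredicate(tidx, bidy, i)) {
          return false;
        }
      }
    }
  }
  return true;
}
```

With signed arithmetic and no overflow, a + b == c + d and a - c == d - b are equivalent, so formsAgree(40000, 40, 8) returns true; the sketch only confirms this empirically over the sampled range, which includes cases where the predicate is true (e.g. tidx = 32768, bidy = 0, i = 0).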

naoyam commented 1 year ago

Thanks for the explanation. Yeah, I found this part (https://github.com/csarofeen/pytorch/pull/2533/files#diff-7853cbfc8ac2e2e18643fb0ba06777e2ee9f4d43065973e9ffb38ebc0fcc0f68R1697), and it makes sense.