NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Only check actually used IDs in predicate elimination #3414

Open jacobhinkle opened 6 days ago

jacobhinkle commented 6 days ago

Now that we can define MmaOp with unbroadcasted inputs (see #3391), it is possible to have ops for which some consumer loop IDs are not used at all for indexing some producers.

For example, the output of MmaOp has three logical dimensions, M, N, and K. These are scheduled by splitting, merging, and swizzling, so in the end the consumer loop domain can contain things like a split of the N dimension into two other IterDomains. Now if we look at the producer A, it has logical size [M K], so there is no N dimension at all. Our current predicate elimination pass places a predicate on this operation when the N dimension is symbolic and we can't prove that the producer is parallelized the same way as this consumer in this dimension. However, since N cannot affect the indexing of the producer A, which has no N dimension, we should skip checking these IterDomains.

This PR does this by performing a BFS from the collection of consumer root IDs that map to producer logical IDs down to the consumer loop domain. Only IDs along that path are checked using the existing conditions.
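Conceptually, the traversal can be sketched as follows. This is a minimal Python model, not the actual nvFuser C++ implementation: the `(inputs, outputs)` tuples stand in for ID transform expressions (Split/Merge/Swizzle), and the ID names are illustrative.

```python
from collections import deque

def collect_index_ids(start_ids, exprs):
    """Forward BFS from the consumer root IDs that map to producer
    logical IDs. `exprs` is a list of (inputs, outputs) tuples modeling
    ID transforms. Returns the set of IDs reachable from start_ids;
    only loop IDs in this set can affect producer indexing and need
    the existing predicate checks."""
    used = set(start_ids)
    queue = deque(start_ids)
    while queue:
        cur = queue.popleft()
        for inputs, outputs in exprs:
            if cur in inputs:
                for out in outputs:
                    if out not in used:
                        used.add(out)
                        queue.append(out)
    return used

# Toy example mirroring the MmaOp case: the consumer has logical
# dimensions [M, N, K], but producer A has logical size [M, K],
# so the BFS starts from {M, K} only.
exprs = [
    (("N",), ("No", "Ni")),   # Split: N -> No, Ni
    (("M", "K"), ("MK",)),    # Merge: M, K -> MK
]
used = collect_index_ids({"M", "K"}, exprs)
# N-derived loop IDs are never reached, so they are skipped.
assert used == {"M", "K", "MK"}
```

In this toy run, the N-derived IDs `No` and `Ni` never enter the reachable set, which is exactly why the pass can ignore them when predicating the access to producer A.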

Detailed example

In the test included in this PR, we have shared memory operand tensors that are scheduled like this:

Inputs:
  T0_g___half[ iS0{i0}, iS1{i1} ]
  T1_g___half[ iS2{i3}, iS3{i4} ]
Outputs:
  T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )

%kernel_math {
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T0_g___half[ iS0{i0}, iS1{i1} ] )
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
   = CpAsyncBulkTensorTile( T1_g___half[ iS2{i3}, iS3{i4} ] )
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 ),
         T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 ))
T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 ));
T3_g___half[ iblockIdx.y27{( ceilDiv(i1, 128) )}, iblockIdx.x29{( ceilDiv(i4, 256) )}, ithreadIdx.y77{2}, ithreadIdx.x111{128}, iS106{32}, iS105{2}, iV109{2} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T6_l___half[ iblockIdx.y23{( ceilDiv(i1, 128) )}, iblockIdx.x25{( ceilDiv(i4, 256) )}, ithreadIdx.y72{2}, ithreadIdx.x101{128}, iS96{32}, iS95{2}, iS99{2} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
} // %kernel_math

T0_g___half[ iS0{i0}, iS1{i1} ]
 logical domain : (iS0{i0}, iS1{i1})
 contiguity: t t
 loop domain : (iS0{i0}, iS1{i1})
T4_s___half[ iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8} ] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T1_g___half[ iS2{i3}, iS3{i4} ]
 logical domain : (iS2{i3}, iS3{i4})
 contiguity: t t
 loop domain : (iS2{i3}, iS3{i4})
T5_s___half[ bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8} ] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[ iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})

Notice that in T4_s the loop broadcasts bblockIdx.x33{1} and bS34{256} are not derived from the logical domain. Instead, they are both products of a Split involving an original "loop broadcast", although this is not currently shown by fusion->printTransforms():

Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}

In the predicate elimination pass with T4_s as producer and T2_l as consumer, the consumer ID iblockIdx.x19{( ceilDiv(i4, 256) )} would normally map to a logical broadcast ID in T4_s, but with these loop domain broadcasts we do not have such a mapping. Before this PR that would cause predication. This PR notices that iblockIdx.x19{( ceilDiv(i4, 256) )} is not actually used for indexing the producer T4_s, so we do not need to worry about out-of-bounds accesses in this dimension.

Without this PR, if we remove the check at https://github.com/NVIDIA/Fuser/blob/3266b9d21cb82272fe6e766b71fb9a9f298de833/csrc/device_lower/analysis/predicate_elimination.cpp#L34-L37, then we generate the following code:

__global__ void nvfuser_none_f0_c0_r0_g0(      
    Tensor<__half, 2, 2> T0,                                                          
    Tensor<__half, 2, 2> T1,                                                          
    const __grid_constant__ TensorMap var0,                                           
    const __grid_constant__ TensorMap var1,                                           
    Tensor<__half, 2, 2> T3) {
  // ...
  nvfuser_index_t i4;
  i4 = 256 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i7;
  i7 = 128 * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i19;
  i19 = 64 * ((nvfuser_index_t)threadIdx.y);
  bool b20;
  b20 = (i4 < T1.logical_size[1LL]) && ((i19 + i7) < T0.logical_size[1LL]);

#pragma unroll 1
  for (nvfuser_index_t i23 = 0; i23 < i2; ++i23) {
    nvfuser_index_t i24;
    i24 = 16 * i23;

    // ... load operands ...

    __syncthreads();
    if ((b20 && (i24 < T0.logical_size[0LL]))) {
      asm volatile(
          "{\n"
          "  .reg .pred p0; \n"
          "  setp.ne.b32 p0, %130, 0;\n"
          "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {..." /*... long parameter list ... */);
    }
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  }
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(0LL) : "memory");
  // ... epilogue and write outputs ...
}

After this PR, the predicate around the wgmma call is removed and the assertOnWarpOps check can be restored.

jacobhinkle commented 6 days ago

!test --diff

jacobhinkle commented 6 days ago

!build

jacobhinkle commented 6 days ago

!test --diff

jacobhinkle commented 3 days ago

!build

jacobhinkle commented 3 days ago

!test

jacobhinkle commented 2 days ago

!test

jacobhinkle commented 2 days ago

!test

jacobhinkle commented 2 days ago

!test

jacobhinkle commented 2 days ago

The two remaining "failures" are tests with IDs of IterType::VectorComponent. These are created in a way similar to how Broadcast IDs are created by a BroadcastOp, so as in the broadcast case we should not be introducing a predicate for them either.

Now that I've verified there are no unexpected cases hitting this code path, I'm going to remove the NVF_THROW and run a codediff. To do that I'll need to disable the test temporarily due to a bug in how codediff handles new or removed tests in the sharded test binaries.

jacobhinkle commented 2 days ago

!test --diff

jacobhinkle commented 2 days ago

I have re-enabled the test now. To summarize:

  1. There were some minor bugs in how the indexing IDs were computed. I fixed these and moved that code into the lower_utils::getIndexIDs utility.
  2. The new check where we return early for IDs that are not index IDs does hit some pre-existing code. However, this only happens for BroadcastOp which introduces new IterType::Broadcast IDs, and for ops like view(DataType dtype) which can introduce new IterType::VectorComponent IDs. Neither of these cases require predication so returning early does not change the functionality.
  3. I have confirmed via codediff there are no changes to our generated code.
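The early return described in item 2 can be sketched as a per-ID decision loop. This is an illustrative Python model, not the real pass: `index_ids` stands for the result of the lower_utils::getIndexIDs utility mentioned above, and `can_omit_check` is a hypothetical stand-in for the pre-existing parallelization checks.

```python
def needs_predicate(consumer_loop_ids, index_ids, can_omit_check):
    """Simplified per-ID predicate decision for one producer/consumer
    pair. Only consumer loop IDs that actually index the producer
    (i.e., are in index_ids) are subjected to the existing checks."""
    for cid in consumer_loop_ids:
        if cid not in index_ids:
            # New in this PR: IDs that cannot affect producer indexing
            # (e.g. N-derived IDs for MmaOp operand A) are skipped.
            continue
        if not can_omit_check(cid):
            return True
    return False

# Stand-in for the existing checks: suppose the N extent is symbolic,
# so the check fails for the N-derived loop ID "No".
can_omit = lambda cid: cid != "No"

# Producer A has no N dimension, so "No" is not an index ID: no predicate.
assert not needs_predicate(["Mo", "No", "Ko"], {"Mo", "Ko"}, can_omit)
# Before this PR every loop ID was checked, which forced a predicate.
assert needs_predicate(["Mo", "No", "Ko"], {"Mo", "No", "Ko"}, can_omit)
```

The two assertions contrast the behavior with and without the index-ID filter: the same failing check on "No" forces a predicate only when "No" is (incorrectly) treated as an index ID of the producer.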

jacobhinkle commented 1 day ago

!test

naoyam commented 1 day ago

Looks like you're still making some changes. Is this ready for review or should we wait?

jacobhinkle commented 1 day ago

Looks like you're still making some changes. Is this ready for review or should we wait?

I made a couple changes based on some downstream testing in #3416. Let me run the CI tests again to verify that everything's working then I'll ping you for another review.

jacobhinkle commented 1 day ago

!test

jacobhinkle commented 1 day ago

@naoyam tests pass on Ampere and I verified the new test passes on Hopper, so I think this is ready for review now.

naoyam commented 23 hours ago

Could you give some examples of generated code before and after the predicate elimination? Please also include ID exprs of related tensors.

jacobhinkle commented 6 hours ago

Could you give some examples of generated code before and after the predicate elimination? Please also include ID exprs of related tensors.

I added the example from the added test to the description of this PR. There are no other known examples currently, since this PR does not affect predication in the only two other relevant cases (new Broadcast IDs from a BroadcastOp, and new VectorComponent IDs).

In the VectorComponent case, we could skip the early return if consumer_id->isVectorComponent() just to be sure we do not change any current behavior.

naoyam commented 4 hours ago

Thanks. I'm concerned about getIndexIDs since it won't work if tensors have loop-promoted IDs, where we would need to look at the whole fusion to get a full loop promotion map.

I wonder why the existing predicate checker can reason about the other producer, T5, but not T4. Which part of the check sets the needs_predicate_ flag before this PR?