NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Enable IdModel tensor indexer when expanded IDs are reshaped #3387

Closed. naoyam closed this 1 week ago.

naoyam commented 1 week ago

Fixes #3299 and #871

The legacy indexer fails when an expanded iter domain is involved in reshape transformations.
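For reference, the failing pattern looks roughly like the following. This is a hedged sketch written against nvFuser's C++ ops API (broadcast, expand, reshape) and the makeConcreteTensor helper from the C++ tests; it is not the verbatim repro from #3299 or #871.

    // Sketch: an expanded broadcast ID flowing into a reshape.
    // Assumes nvFuser's C++ API; not the exact repro from the issues.
    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());

    // [4, 5] input plus a broadcast dim, expanded from 1 to 6.
    // The resulting bS{1 ex 6} iter domain has no backing allocation.
    auto tv0 = makeConcreteTensor({4, 5});
    fusion->addInput(tv0);
    auto tv1 = broadcast(tv0, {false, false, true});
    auto tv2 = expand(
        tv1,
        {IrBuilder::create<Val>(4L),
         IrBuilder::create<Val>(5L),
         IrBuilder::create<Val>(6L)});

    // The reshape merges the expanded ID with a real ID
    // ([4, 5, 6] -> [4, 30]), which is the transformation the
    // legacy indexer could not index through.
    auto tv3 = reshape(tv2, {4, 5, 6}, {4, 30});
    fusion->addOutput(tv3);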

naoyam commented 1 week ago

!test --diff

naoyam commented 1 week ago

!test --diff

naoyam commented 1 week ago

The diff check found three kinds of code differences. None of them seems problematic.

  1. Lack of magic-zero support. The new indexer doesn't (yet) support the magic zero; we should probably reevaluate whether it's still beneficial (see the sketch after this list). Since it's not a functional requirement, I don't think it's a blocker for this fix.
  2. Using 0 for iter domains whose extents are known to be 1. For example, here's the diff for AliasTest.MergeTwoExpandedBroadcasts:
     __global__ void nvfuser_N(Tensor<float, 3, 3> T0, Tensor<float, 2, 2> T1) {
    -  nvfuser_index_t i0;
    -  i0 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
       float T2[1];
       T2[0] = 0;
       T2[0]
         = T0[0];
    -  if ((i0 < 20)) {
    +  if ((((nvfuser_index_t)threadIdx.x) < 20)) {
         float T3[1];
         T3[0]
           = T2[0];
    -    T1[i0]
    +    T1[((nvfuser_index_t)threadIdx.x)]
           = T3[0];
       }
     }

For T1, the new indexer uses just threadIdx.x, whereas the old indexer produces ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x)). This is because the new indexer assigns 0 to the iter domain parallelized by BIDx. Here's the tensor in question:

T1_g_float[ iblockIdx.x22{( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 1 ), 128) ), 1) ) ex ( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 6 ), 128) ), 1) )}, iUS23{1}, ithreadIdx.x21{128} ] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS15{( 4 * 5 )}, bS16{1 ex 6})
 contiguity: t n
  Merge: iS15{( 4 * 5 )} and bS16{1 ex 6} -> iS19{( ( 4 * 5 ) * 1 ) ex ( ( 4 * 5 ) * 6 )}
  Split: iS19{( ( 4 * 5 ) * 1 ) ex ( ( 4 * 5 ) * 6 )} by factor 128 -> iS20{( ceilDiv(( ( 4 * 5 ) * 1 ), 128) ) ex ( ceilDiv(( ( 4 * 5 ) * 6 ), 128) )}, ithreadIdx.x21{128}
  Split: iS20{( ceilDiv(( ( 4 * 5 ) * 1 ), 128) ) ex ( ceilDiv(( ( 4 * 5 ) * 6 ), 128) )} by factor 1 -> iblockIdx.x22{( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 1 ), 128) ), 1) ) ex ( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 6 ), 128) ), 1) )}, iUS23{1}
 loop domain : (iblockIdx.x22{( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 1 ), 128) ), 1) ) ex ( ceilDiv(( ceilDiv(( ( 4 * 5 ) * 6 ), 128) ), 1) )}, iUS23{1}, ithreadIdx.x21{128})
} 

Since the extent of iblockIdx.x22 can be statically computed as 1, the index mapped to that ID is 0, which is why only threadIdx.x appears in the final index. The worked check below spells out the arithmetic.
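A worked check, using nvFuser's ceilDiv(a, b) = (a + b - 1) / b:

    // Worked check of i22's static extents from the dump above.
    #include <cstdint>

    constexpr int64_t ceilDiv(int64_t a, int64_t b) {
      return (a + b - 1) / b;
    }

    // Real extent: ceilDiv(ceilDiv(20, 128), 1) == ceilDiv(1, 1) == 1.
    static_assert(ceilDiv(ceilDiv(4 * 5 * 1, 128), 1) == 1);
    // Expanded extent: ceilDiv(ceilDiv(120, 128), 1) == ceilDiv(1, 1) == 1.
    static_assert(ceilDiv(ceilDiv(4 * 5 * 6, 128), 1) == 1);
    // With an extent of 1, the loop index for i22 must be 0, so the
    // blockIdx.x term vanishes and only threadIdx.x remains.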

  3. A few cases where the names of some output tensors got reordered. This is not a new issue; not sure why, but there still seems to be some non-determinism somewhere. For example:
tests/python/test_ops.py::test_correctness_var_mean_float64
-__global__ void nvfuser_N(Tensor<double, 1, 1> T0, Tensor<double, 0, 0> T8, Tensor<double, 0, 0> T7) {
+__global__ void nvfuser_N(Tensor<double, 1, 1> T0, Tensor<double, 0, 0> T7, Tensor<double, 0, 0> T8) {
   alignas(16) extern __shared__ char array[];
   void* shared_mem = array;
   bool b0;
   b0 = ((nvfuser_index_t)threadIdx.x) == 0;
   Tensor<double, 1, 1> s1;
@@ -48,11 +48,11 @@

   double T1[1];
   T1[0]
     = T9[0]
     / d8;
   if (b0) {
-    T8[0]
+    T7[0]
        = T1[0];
   }
   double T10[1];
   broadcast::blockBroadcast<true, false, false, true>(T10[0], T1[0], static_cast<double*>(shared_mem), true);
   double T2[1];
@@ -74,9 +74,9 @@

   double T4[1];
   T4[0]
     = T2[0]
     * d14;
   if (b0) {
-    T7[0]
+    T8[0]
        = T4[0];
   }
 }
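For context on item 1 above: the magic zero is a value that is always zero at runtime but that the compiler cannot prove to be zero; adding it to selected index expressions keeps nvcc from hoisting or over-simplifying the indexing arithmetic in unrolled loops. A minimal CUDA sketch of the idea follows; it is an illustration only, not nvFuser's actual NVFUSER_DEFINE_MAGIC_ZERO machinery.

    // Illustration of the magic-zero trick (not nvFuser's real macros).
    __global__ void magic_zero_sketch(const float* in, float* out, int n) {
      // nvfuser_zero is 0 at runtime, but since it is read back from
      // shared memory, the compiler cannot constant-fold it away.
      __shared__ int nvfuser_zero_s;
      if (threadIdx.x == 0) {
        nvfuser_zero_s = 0;
      }
      __syncthreads();
      int nvfuser_zero = nvfuser_zero_s;

      // Adding the opaque zero to an index defeats hoisting and
      // strength reduction of the indexing arithmetic.
      int i = threadIdx.x + blockDim.x * blockIdx.x;
      if (i < n) {
        out[i + nvfuser_zero] = in[i + nvfuser_zero];
      }
    }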
naoyam commented 1 week ago

!test

naoyam commented 1 week ago

!test

naoyam commented 1 week ago

Pinging @jacobhinkle

naoyam commented 1 week ago

!build