csarofeen / pytorch

Silent wrong result on broadcasting with split and merge #1880

Status: Open · zasdfgbnm opened this issue 2 years ago

zasdfgbnm commented 2 years ago

🐛 Describe the bug

TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  auto tv1 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->split(0, 32);
  tv3->reorder({{1, -1}});
  tv3->split(1, 32);
  tv3->reorder({{2, -1}});
  tv3->merge(2);
  tv3->split(2, 1);
  tv3->split(2, 32);
  tv3->split(2, 8, false);
  tv3->axis(-2)->parallelize(ParallelType::TIDx);
  tv3->axis(-3)->parallelize(ParallelType::TIDy);
  tv3->split(0, 1);
  tv3->axis(1)->parallelize(ParallelType::Unswitch);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  scheduler_utils::parallelizeAllLike(tv3);
  tv2->axis(2)->parallelize(ParallelType::Unroll);
  InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
  tree.traverse(&inline_propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::arange(64, options).view({32, 2});
  at::Tensor input1 = at::arange(32, options) * 0.01;

  fusion.printMath();
  fusion.print();
  fusion.printKernel();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto outputs = fe.runFusion({input0, input1});
  std::cout << outputs[0] << std::endl;

  auto tv_ref = input0 + input1.unsqueeze(1);

  testValidate(
      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}

Removing the unroll will fix the issue.

Versions

devel

zasdfgbnm commented 2 years ago

Generated code:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 1> T1, Tensor<float, 2> T4) {
  NVFUSER_DEFINE_MAGIC_ZERO
  #pragma unroll 1
  for(nvfuser_index_t i64 = 0; i64 < (ceilDiv((ceilDiv(T0.size[0], 32)), 1)); ++i64) {
    if (((((((((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)))) && (((i64 * 32) + (((((ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) < T0.size[0])) && ((((ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) < (ceilDiv((ceilDiv(32, 1)), 32)))) && (((i64 * 32) + ((((((ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32)) < T0.size[0])) && (((((ceilDiv(T0.size[1], 32)) - 1) * 32) + ((((((ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)) * 7) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32)) < T0.size[1]))) {
      float T2[(8 * 1)];
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
        T2[i55] = 0;
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
        T2[i55]
           = T1[(((i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) * T1.stride[0])];
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      #pragma unroll 1
      for(nvfuser_index_t i66 = 0; i66 < (ceilDiv(T0.size[1], 32)); ++i66) {
        #pragma unroll
        for(nvfuser_index_t i67 = 0; i67 < 8; ++i67) {
          int64_t i120;
          i120 = (i64 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32);
          int64_t i117;
          i117 = (i66 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32);
          float T3[1];
          T3[0]
             = T2[i67];
          T4[(i120 * T0.size[1]) + i117]
            = T3[0]
            + T0[(i120 * T0.stride[0]) + (i117 * T0.stride[1])];
        }
        NVFUSER_UPDATE_MAGIC_ZERO
      }
    } else {
      float T2[(8 * 1)];
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
        if ((((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)))) {
          T2[i55] = 0;
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 8; ++i55) {
        int64_t i170;
        i170 = (i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x));
        if ((((i170 < T0.size[0]) && ((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) < (ceilDiv((ceilDiv(32, 1)), 32)))) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))))) {
          T2[i55]
             = T1[(i170 * T1.stride[0])];
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO
      #pragma unroll 1
      for(nvfuser_index_t i66 = 0; i66 < (ceilDiv(T0.size[1], 32)); ++i66) {
        #pragma unroll
        for(nvfuser_index_t i67 = 0; i67 < 8; ++i67) {
          int64_t i201;
          i201 = (i64 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32);
          int64_t i198;
          i198 = (i66 * 32) + ((((((i67 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) % 32);
          float T3[1];
          T3[0]
             = T2[i67];
          if ((((i201 < T0.size[0]) && (i198 < T0.size[1])) && (((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))))) {
            T4[(i201 * T0.size[1]) + i198]
              = T3[0]
              + T0[(i201 * T0.stride[0]) + (i198 * T0.stride[1])];
          }
        }
        NVFUSER_UPDATE_MAGIC_ZERO
      }
    }
  }
}

This predicate is wrong:

(((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8)))

It should be ...ceilDiv((32 * 32), 1)...

And the index is also wrong:

T1[(((i64 * 32) + (((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv(32, 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x))) * T1.stride[0])];

There should be a / 32 at the end.
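
For reference, here is a rough guess (not verified) at what the corrected expressions would look like, applying the two fixes above and following the pattern of the i120 / i201 computations elsewhere in the same kernel:

// predicate: the extent should come from the merged (32 * 32) domain
(((nvfuser_index_t)threadIdx.y) < (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8)))
// load: the flattened offset needs a trailing / 32 before multiplying by T1.stride[0]
T2[i55]
   = T1[(((i64 * 32) + ((((((i55 + nvfuser_zero) * (ceilDiv((ceilDiv((ceilDiv((32 * 32), 1)), 32)), 8))) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) / 32)) * T1.stride[0])];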

zasdfgbnm commented 2 years ago

I am not very familiar with index and predicate calculation, but my guess is that, when unrolled, the code that reads T1 is generated based on T2. But neither T1 nor T2 has full information about the underlying transformation; the broadcast dimension only appears at T3. So looking locally at the transformations of T1 and T2 when generating code is not sufficient. We should use something with more complete information for predicate and index calculation.

shmsong commented 2 years ago

Just curious does it fail as well if you put all the scheduling on tv4 instead of tv3?

shmsong commented 2 years ago

Just curious does it fail as well if you put all the scheduling on tv4 instead of tv3?

I see it's the unroll on the inner dimension that was the problem so it would still fail I guess.

zasdfgbnm commented 2 years ago

Just curious does it fail as well if you put all the scheduling on tv4 instead of tv3?

Yes, it does

shmsong commented 2 years ago

Problem seems to come from outer split.

Could reproduce the failure with

  auto tv0 = makeConcreteTensor({32,2});
  auto tv1 = makeConcreteTensor({32});
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->merge(0);
  tv3->split(0, 8, false);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  InlinePropagator inline_propagator(tv3, -2);
  tree.traverse(&inline_propagator);

The IterDomains to the right of the CA axes after the outer split are implicitly concretized, so they would need to bind the concrete info, but they are not loop mapped to the other concretized loops.
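
Concretely, with the {32, 2} / {32} shapes above (and assuming split(0, 8, false) places the factor on the outer domain), the propagated transforms give roughly:

  T3: [ iS{32}, bS{1} ]  ->  merge(0)  ->  [ iS{( 32 * 1 )} ]  ->  split(0, 8, false)  ->  [ iS{8}, iS{( ceilDiv(( 32 * 1 ), 8) )} ]
  T4: [ iS{32}, iS{2} ]  ->  merge(0)  ->  [ iS{( 32 * 2 )} ]  ->  split(0, 8, false)  ->  [ iS{8}, iS{( ceilDiv(( 32 * 2 ), 8) )} ]

So the leaf of T3 to the right of the CA axis gets its extent from the broadcast-sized merge (32 * 1), while the corresponding loop of T4 uses the concretized extent (32 * 2), and the two are not loop mapped.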

zasdfgbnm commented 2 years ago

Looks like the outer split might not be the root cause of the problem. I removed the outer split, but the problem is still there:

TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  auto tv1 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->split(0, 32);
  tv3->reorder({{1, -1}});
  tv3->split(1, 32);
  tv3->reorder({{2, -1}});
  tv3->merge(2);
  tv3->split(2, 1);
  tv3->split(2, 128);
  tv3->axis(-2)->parallelize(ParallelType::TIDx);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  scheduler_utils::parallelizeAllLike(tv3);
  tv2->axis(-3)->parallelize(ParallelType::Unroll);
  InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
  tree.traverse(&inline_propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::arange(64, options).view({32, 2});
  at::Tensor input1 = at::arange(32, options) * 0.01;

  fusion.printMath();
  fusion.print();
  fusion.printKernel();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto outputs = fe.runFusion({input0, input1});
  std::cout << outputs[0] << std::endl;

  auto tv_ref = input0 + input1.unsqueeze(1);

  testValidate(
      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}
shmsong commented 2 years ago

Removing ComputeAtMode::MostInlined made it correct.

The difference seems to be on the ca_pos of T2:

vvv This one failed:

%kernel_math {
T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 1 )
   = T1_g[ iS41{( ceilDiv(i3, 32) )}, iS45{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, iS46{128}, iS44{1} ];
T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 1)
   = broadcast( T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 1 ) )
T4_g[ iS17{( ceilDiv(i3, 32) )}, iS19{( ceilDiv(i2, 32) )}, iS24{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x25{128}, iS23{1} ] ca_pos( 5 ) produce_pos( 5)
   = T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 1)
   + T0_g[ iS26{( ceilDiv(i1, 32) )}, iS28{( ceilDiv(i2, 32) )}, iS33{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, iS34{128}, iS32{1} ];
}

vvv This one passed:

%kernel_math {
T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 4 )
   = T1_g[ iS41{( ceilDiv(i3, 32) )}, iS45{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, iS46{128}, iS44{1} ];
T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 5)
   = broadcast( T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 4 ) )
T4_g[ iS17{( ceilDiv(i3, 32) )}, iS19{( ceilDiv(i2, 32) )}, iS24{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x25{128}, iS23{1} ] ca_pos( 5 ) produce_pos( 5)
   = T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 5)
   + T0_g[ iS26{( ceilDiv(i1, 32) )}, iS28{( ceilDiv(i2, 32) )}, iS33{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, iS34{128}, iS32{1} ];
}

Sorry, I didn't paste the right printout earlier. Still looking for the root cause.

shmsong commented 2 years ago

Looks like an inner split of the inner broadcast has a similar issue to the outer split, if not inlined.


// This allocation size wouldn't be safe, hard to know exactly how many of T2 we'd need without the
//  concrete sizes.
float T2[((ceilDiv((ceilDiv(32, 1)), 128)) * 1)];   

#pragma unroll
for(nvfuser_index_t i41 = 0; i41 < (ceilDiv((ceilDiv(32, 1)), 128)); ++i41) {
  T2[i41] = 0;
}
NVFUSER_UPDATE_MAGIC_ZERO
#pragma unroll
for(nvfuser_index_t i41 = 0; i41 < (ceilDiv((ceilDiv(32, 1)), 128)); ++i41) {
  int64_t i71;
  i71 = (i51 * 32) + (((i41 + nvfuser_zero) * 64) + ((nvfuser_index_t)threadIdx.x));
  if (((i71 < T0.size[0]) && ((((i41 + nvfuser_zero) * 64) + ((nvfuser_index_t)threadIdx.x)) < (ceilDiv(32, 1))))) {
    T2[i41]
       = T1[(i71 * T1.stride[0])];
  }
}
NVFUSER_UPDATE_MAGIC_ZERO
float T3[((ceilDiv((ceilDiv((32 * 32), 1)), 64)) * 1)];
#pragma unroll
for(nvfuser_index_t i44 = 0; i44 < (ceilDiv((ceilDiv((32 * 32), 1)), 64)); ++i44) {
// Indexing of T2 would need to be fixed, but that also depends on how T2 is allocated.
  T3[i44]
     = T2[i44];
}
csarofeen commented 2 years ago

Other related repros:

TEST_F(NVFuserTest, FusionBroadcastingIndexingOuter_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = TensorViewBuilder().shape({6, 5}).dtype(DataType::Float).contiguity({true, true}).build();
  auto tv1 = TensorViewBuilder().shape({6}).dtype(DataType::Float).contiguity({true}).build();
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->merge(0);
  tv3->split(0, 4, false);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  auto inline_propagator = InlinePropagator(tv3, 1, ComputeAtMode::BestEffort);
  tree.traverse(&inline_propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::arange(6*5, options).view({6, 5});
  at::Tensor input1 = at::arange(6, options) * 0.01;

  fusion.printMath();
  fusion.print();
  fusion.printKernel();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto outputs = fe.runFusion({input0, input1});
  std::cout << outputs[0] << std::endl;

  auto tv_ref = input0 + input1.unsqueeze(1);

  testValidate(
      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}

Setting T2 and T3 seems right, but accessing T3 seems wrong. I think the T3 indexing should be something like: T3[((i26 * (ceilDiv((6 * 5), 4))) + (i27 + nvfuser_zero))] / 5 / ceilDiv(6, 4). This seems like a zero-index propagation issue, or an extent propagation issue.

TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  auto tv1 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->split(0, 32);
  tv3->reorder({{1, -1}});
  tv3->split(1, 32);
  tv3->reorder({{2, -1}});
  // [Io0, Bo1, Ii0(32), Bi1(32)]
  tv3->merge(2);
  tv3->split(2, 128);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  scheduler_utils::parallelizeAllLike(tv3);
  InlinePropagator inline_propagator(tv3,1);
  tree.traverse(&inline_propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::arange(64, options).view({32, 2});
  at::Tensor input1 = at::arange(32, options) * 0.01;

  fusion.printMath();
  fusion.print();
  fusion.printKernel();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto outputs = fe.runFusion({input0, input1});

  auto tv_ref = input0 + input1.unsqueeze(1);

  testValidate(
      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}

This won't fail if the TIDx binding is removed.

The static-shaped version fails the parallelization check instead of silently producing a wrong result.

csarofeen commented 2 years ago

@shmsong posted a WAR (workaround) for transpose scheduling in https://github.com/csarofeen/pytorch/pull/1854, so I'm not marking this as high priority, but we should either quarantine ourselves off from this regime or start a deep dive to fix the indexing.

csarofeen commented 2 years ago

@zasdfgbnm could you add a case with vectorization as mentioned in https://github.com/csarofeen/pytorch/pull/1918 ? Or do you think there isn't anything specific to vectorization for this bug?

csarofeen commented 2 years ago

Some thoughts in this space. A simple motivating case (all split factors are symbolic; they will be denoted s0, s1, s2):

tv0{i2}
tv2{i1, i2} = broadcast(tv0, {true, false}) + tv1
tv4{i0, i1, i2} = broadcast(tv2, {true, false, false}) + tv3

tv4->split(2, s2)->split(1, s1)->split(0, s0)
  | tv4{i0//s0, s0, i1//s1, s1, i2//s2, s2}
tv4->merge(2, 4)->merge(0, 2)
  | tv4{(i0//s0)*(i1//s1)*(i2//s2), s0, s1, s2}
tv4->merge(2, 3)->merge(1, 2)
  | tv4{(i0//s0)*(i1//s1)*(i2//s2), s0*s1*s2}

tv0->computeAt(tv4, 1)
The fusion IR looks like:
tv0{                   i2//s2 |, s2}
tv2{         (i1//s1)*(i2//s2)|, s1*s2}
tv4{(i0//s0)*(i1//s1)*(i2//s2)|, s0*s1*s2}

So, based on compute at, tv0 looks like it's iterating over:

{(i0//s0)*(i1//s1)*(i2//s2), | s2}

So it looks like we don't have to generate that many values of tv0. But that's actually wrong: we have to generate every use of tv0 inlined into tv4, so really it should be:

{(i0//s0)*(i1//s1)*(i2//s2)|, s0*s1*s2}

The indexing math is annoying here, and I might have made minor mistakes, but if so that shouldn't impact the conclusion. tv4 would be indexed like:

[
(i/((i1//s1)*(i2//s2))         )*s0 + (j/(s1*s2)),
(i%((i1//s1)*(i2//s2))/(i2//s2))*s1 + (j%(s1*s2)/s2),
(i%((i1//s1)*(i2//s2))%(i2//s2))*s2 + (j%(s1*s2)%s2)
]

We know tv4{i2} exactly maps to tv0{i2}; in the indexing of tv4 this dimension is a function of both i (extent (i0//s0) * (i1//s1) * (i2//s2)) and j (extent s0*s1*s2). In other words, there's no way to index into tv0 with "just" i of extent (i0//s0) * (i1//s1) * (i2//s2) and j of extent s2. So tv0 and tv2 have to "inherit" their loops, even to the right of the loop nest structure, from tv4.

I think this can get even stranger, where a tensor could look like its leaves are 2D, but we would really have to generate and use a 3D loop nest for it. As hard as this seems, I think making it work isn't that bad. Effectively, if we "promote" the loops present in the compute-at loop map and generate the loop nests from consumers to producers (we already do this, but we might need slightly different logic here), we might be "done" if we then just traverse from the loop nest IDs through exact maps/expressions between exact maps. One part I don't know how to do exactly is how to index based on the compute-at map while supporting shift and gather.
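
A small, self-contained sketch of that argument (hypothetical names, plain int64_t extents standing in for the symbolic i0//s0 etc.; an illustration of the indexing above, not nvfuser code):

#include <cstdint>

// Loop nest tv0 has to "inherit" from tv4 in the example above.
// i0_s0, i1_s1, i2_s2 stand for i0//s0, i1//s1, i2//s2.
void consume_tv4(const float* tv0,  // 1D input of extent i2
                 int64_t i0_s0, int64_t i1_s1, int64_t i2_s2,
                 int64_t s0, int64_t s1, int64_t s2) {
  for (int64_t i = 0; i < i0_s0 * i1_s1 * i2_s2; ++i) {  // tv4's outer leaf
    for (int64_t j = 0; j < s0 * s1 * s2; ++j) {         // tv4's inner leaf
      // The i2 component of tv4's index (third entry of the bracket above)
      // is the only way to recover tv0's root index:
      int64_t idx_i2 = (i % (i1_s1 * i2_s2) % i2_s2) * s2 + (j % (s1 * s2) % s2);
      float v = tv0[idx_i2];  // tv0 (and tv2) must be generated inside this
      (void)v;                // full loop nest, not a nest of just {outer, s2}
    }
  }
}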

naoyam commented 2 years ago

Looks like the outer split might not be the root cause of the problem. I removed the outer split, but the problem is still there:

TEST_F(NVFuserTest, FusionBroadcastingIndexing_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  auto tv1 = makeSymbolicTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = set(tv1);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv0);
  fusion.addOutput(tv4);

  tv3->split(0, 32);
  tv3->reorder({{1, -1}});
  tv3->split(1, 32);
  tv3->reorder({{2, -1}});
  tv3->merge(2);
  tv3->split(2, 1);
  tv3->split(2, 128);
  tv3->axis(-2)->parallelize(ParallelType::TIDx);

  MaxRootDomainInfoSpanningTree tree(tv3);
  TransformPropagator tp(tv3);
  tree.traverse(&tp);
  scheduler_utils::parallelizeAllLike(tv3);
  tv2->axis(-3)->parallelize(ParallelType::Unroll);
  InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
  tree.traverse(&inline_propagator);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::arange(64, options).view({32, 2});
  at::Tensor input1 = at::arange(32, options) * 0.01;

  fusion.printMath();
  fusion.print();
  fusion.printKernel();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto outputs = fe.runFusion({input0, input1});
  std::cout << outputs[0] << std::endl;

  auto tv_ref = input0 + input1.unsqueeze(1);

  testValidate(
      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}

Just looked into this fusion. Looks like it's a parallelization problem rather than indexing. Here's the kernel math:

Inputs:
  T0_g[ iS26{( ceilDiv(i1, 32) )}, iS28{( ceilDiv(i2, 32) )}, iS33{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, iS34{128}, iS32{1} ], float
  T1_g[ iS41{( ceilDiv(i3, 32) )}, iS45{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, iS46{128}, iS44{1} ], float
Outputs:
  T4_g[ iS17{( ceilDiv(i3, 32) )}, iS19{( ceilDiv(i2, 32) )}, iS24{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x25{128}, iS23{1} ] ca_pos( 5 ) produce_pos( 5), float

%kernel_math {
T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 1 )
   = T1_g[ iS41{( ceilDiv(i3, 32) )}, iS45{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, iS46{128}, iS44{1} ];
T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 1)
   = broadcast( T2_l[ iS35{( ceilDiv(i3, 32) )}, iUR39{( ceilDiv(( ceilDiv(32, 1) ), 128) )}, ithreadIdx.x40{128}, iS38{1} ] ca_pos( 1 ) )
T4_g[ iS17{( ceilDiv(i3, 32) )}, iS19{( ceilDiv(i2, 32) )}, iS24{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x25{128}, iS23{1} ] ca_pos( 5 ) produce_pos( 5)
   = T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, iS15{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, ithreadIdx.x16{128}, iS14{1} ] ca_pos( 5 ) produce_pos( 1)
   + T0_g[ iS26{( ceilDiv(i1, 32) )}, iS28{( ceilDiv(i2, 32) )}, iS33{( ceilDiv(( ceilDiv(( 32 * 32 ), 1) ), 128) )}, iS34{128}, iS32{1} ];
}

Notice that the TIDx parallelization is done for all tensors, but those parallelized domains are not exactly mapped. This should be fine as long as the tensors are in shared memory (or in global memory with a grid sync), but we are using local memory for T2 and T3.

The validation passes if the parallelization is removed. It also passes if T2 is stored in shared memory and a syncthreads is inserted; the sync isn't inserted automatically, so I added one manually, and the validation passed.
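
A minimal sketch of that manual workaround in the test, assuming the usual TensorView API (a workaround for this repro, not a fix for the underlying bug):

  // Keep the non-exactly-mapped TIDx communication out of local registers.
  tv2->setMemoryType(MemoryType::Shared);
  // A block sync between writing T2 and reading it to produce T3 still has to
  // be added by hand, since (per the above) it is not inserted automatically here.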

So, the first problem is that the parallel validation fails to detect the invalid parallelization, which we need to fix.

Another thing to consider is whether the current behavior of parallelizeAllLike is desirable. Currently we propagate parallel types from tensor to tensor, but at the same time it should be possible to automatically place tensors on shared or global memory if necessary. In this particular case, T2 should have been automatically placed on shared memory.

Since parallel types and memory types are tightly related, it seems to make sense to update both with parallelizeAllLike, but it may just make things more complex, and the current, simpler design may be better. After all, we don't update the memory type when changing the parallel type of a single tensor (i.e., tv->axis(0)->parallelize(ParallelType::TIDx) won't automatically place tv on shared memory).

Any thoughts?

csarofeen commented 2 years ago

What does inlining a parallel dimension mean?

naoyam commented 2 years ago

Sorry, what do you mean?

csarofeen commented 2 years ago

Aren't our validation/communication rules dependent on whether a parallel dimension is inlined or not? What happens if you most-inline your example?

naoyam commented 2 years ago

Yes, derivedFromRootCAAxes is where we process domains differently, but it doesn't catch this case. I should have an updated version pretty soon.

csarofeen commented 2 years ago

What's the difference between processing a parallel dimension as being derivedFromRootCAAxes and not? I have a bigger question here about how we want to treat these kinds of cases. Do we want this kind of different behavior for parallel dims? Or should they always just be treated as inlined?

naoyam commented 2 years ago

The problem here is to identify when a producer ID and a consumer ID have the same index. When they are exactly mapped, they are guaranteed to have the same index. It isn't generally the case when they are permissively mapped, but since the producer tensor is actually indexed using the consumer CA domains, the producer ID ends up using the same index in some cases. The question is when exactly that happens.

For example, this is a simplified case based on the above repro by Xiang. T3 is the only consumer of T2.

T2_l[ iS29{( ceilDiv(i3, 32) )}, ithreadIdx.x31{( ceilDiv(32, 16) )}, iS32{16} ] ca_pos( 1 )
 root domain : (iS3{i3})
  Split: iS3{i3} by factor 32 -> iS29{( ceilDiv(i3, 32) )}, iS30{32}, start offset: 0, stop offset: 0
  Split: iS30{32} by factor 16 -> ithreadIdx.x31{( ceilDiv(32, 16) )}, iS32{16}, start offset: 0, stop offset: 0
T3_l[ iS8{( ceilDiv(i3, 32) )}, bS10{( ceilDiv(1, 32) )}, ithreadIdx.x13{( ceilDiv(( 32 * 32 ), 16) )}, iS14{16} ] ca_pos( 4 ) produce_pos( 4)
 root domain : (iS4{i3},bS5{1})
  Split: iS4{i3} by factor 32 -> iS8{( ceilDiv(i3, 32) )}, iS9{32}, start offset: 0, stop offset: 0
  Split: bS5{1} by factor 32 -> bS10{( ceilDiv(1, 32) )}, bS11{32}, start offset: 0, stop offset: 0
  Merge: iS9{32} and bS11{32} -> iS12{( 32 * 32 )}
  Split: iS12{( 32 * 32 )} by factor 16 -> ithreadIdx.x13{( ceilDiv(( 32 * 32 ), 16) )}, iS14{16}, start offset: 0, stop offset: 0

The problem is the forwarding of Merge: iS9{32} and bS11{32} -> iS12{( 32 * 32 )}. It seems a little tricky here as the output ID is not a broadcast anymore and has an extent of 32 x 32. I suspect if the input broadcast ID has an extent of 1, this would just work. Maybe we should also think about not expanding the merge output domain by a broadcast extent, but for now assuming this behavior, the index of T2 is simply something like i * ceilDiv(32, 16) * 16 + j * 16 + k, whereas that of T3 is i * ceilDiv(32 * 32, 16) * 16 + k (ignore threadIdx.x for now). Since the CA position of T2 is just 1, it only shares iS8{( ceilDiv(i3, 32) )} of T3, and since it's exactly mapped with iS29{( ceilDiv(i3, 32) )}, it doesn't affect the index of T2.

Suppose the CA position of T2 is 2. Now the index of T2 is computed also using the domain that corresponds to ithreadIdx.x31{( ceilDiv(32, 16) )} of T2, which means ithreadIdx.x13{( ceilDiv(( 32 * 32 ), 16) )}. With this ID, the non-exactness of the forwarded merge is resolved as T2 effectively uses the same information as T3 to index its root domains. So in this case both T2 and T3 result in using the same index even though they are only mapped permissively.

Overall, the problem here is when we can say two permissively mapped domains have the same index. By definition, they must have a forwarded merge, which can result in different indices. However, when either of the outputs of a forwarded merge is also used for indexing the producer through computeAt, they should end up using the same index.
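
Putting the index expressions from the comment above side by side (threadIdx.x omitted, as stated there):

  T2 with ca_pos( 1 ): i * ceilDiv(32, 16) * 16 + j * 16 + k
  T3:                  i * ceilDiv(32 * 32, 16) * 16 + k
  T2 with ca_pos( 2 ): the second loop index is taken from T3's ithreadIdx.x13{( ceilDiv(( 32 * 32 ), 16) )},
                       so T2 ends up computing the same index as T3 despite only being permissively mapped.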

A similar analysis may be required for trivial reductions as they are basically a reverse operation.

csarofeen commented 1 year ago

I think there are two issues in this thread (uncertain whether 3 below really belongs in this thread or is just a separate thought):

(1) We should "demote" broadcast dimensions that are merged in if they're not concretized. This is the example you're talking about, @naoyam, where effectively we could just pull the broadcast out of the transformations so it isn't materialized, preventing unnecessary expansions through merges.

(2) We should "promote" broadcast dimensions when they are materialized: if we have a broadcast merged in, and that iteration domain maps to a consumer with that dimension, we want to promote the loop mapping to be on the consumer.

(3) We might want to consider any parallel dimensions across producers and consumers that are permissively mapped as being "inlined", to prevent parallel communication from being required in situations where there's no normalization pattern (just an input broadcast being concretized).

csarofeen commented 1 year ago

Extra note on 1: we can't demote broadcast dimensions if their merge is part of a reshape.