csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Predicate elimination bug #1571

Closed naoyam closed 2 years ago

naoyam commented 2 years ago

This logic seems wrong (blaming myself):

https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_predicate.cpp#L343-L348

filters.emplace_back([this](Expr* expr) {
    if (expr->isA<ReductionOp>()) {
      auto input = expr->inputs()[0]->as<TensorView>();
      auto input_def = input->definition();
      // When input_def is null, input must be an input to the fusion,
      // so that must be allocated on global memory. Since we don't omit
      // predication for expressions involving global memory, this
      // should never occur.
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr, "Inconsistent input found: ", input);

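      // Reading of the check below (comment added for clarity): return
      // true (keep this reduction predicated) only when the producer's own
      // predicate was eliminated and the producer is not a reduction of
      // the same op type.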
      if (non_predicated_exprs_.find(input_def) !=
              non_predicated_exprs_.end() &&
          !(input_def->isA<ReductionOp>() &&
            (expr->as<ReductionOp>()->getReductionOpType() ==
             input_def->as<ReductionOp>()->getReductionOpType()))) {
        return true;
      }
    }
    return false;
  });

Repro:

TEST_F(NVFuserTest, FusionPredicateElimination2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  std::vector<int64_t> shape({10, 11});

  auto tv0 = makeConcreteTensor(shape);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {1});
  auto tv3 = add(tv2, IrBuilder::create<Double>(1));

  fusion.addOutput(tv3);

  tv1->split(1, 4);
  tv1->split(0, 4);
  tv2->split(1, 4);
  tv2->split(0, 4);

  fusion.printMath();
  fusion.printKernel();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn(shape, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});

  auto ref = (t0 + 1).sum({1}) + 1;

  testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}

Validation fails:

Validation error in output 0 on line 17049 in file ../torch/csrc/jit/codegen/cuda/test/test_gpu.cpp.
  Detected abs error of: 1.81327
    absolute tolerance was set to 4.4778e-06
    and relative tolerance set to 4.4778e-08
Exception raised from testValidate at ../torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h:436 (most recent call first):

tv2's predicate is eliminated, which is wrong:

  for(nvfuser_index_t i24 = 0; i24 < (ceilDiv(10, 4)); ++i24) {
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 4; ++i25) {
      int64_t i106;
      i106 = (i24 * 4) + i25;
      #pragma unroll
      for(nvfuser_index_t i26 = 0; i26 < (ceilDiv(11, 4)); ++i26) {
        #pragma unroll
        for(nvfuser_index_t i27 = 0; i27 < 4; ++i27) {
          T2[i106]
            = T2[i106]
            + T1[(i106 * 11) + ((i26 * 4) + i27)];
        }
      }
    }
  }
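
To make the failure concrete, here is a standalone simulation of the loop nest above (my own sketch, not nvFuser code). ceilDiv(10, 4) and ceilDiv(11, 4) both round up, so the loops cover a 12x12 iteration space over the 10x11 tensor, and with the predicate gone each T2[i106] also accumulates column index 11, which aliases the first element of the next row:

// Standalone sketch: simulate the unpredicated reduction loop nest on a
// padded 10x11 buffer of ones.
#include <cstdio>
#include <vector>

constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int R = 10, C = 11;
  const int Rpad = ceilDiv(R, 4) * 4; // the loops cover 12 rows, not 10
  const int Cpad = ceilDiv(C, 4) * 4; // the loops cover 12 columns, not 11
  // Pad both buffers so the out-of-range accesses stay in bounds here.
  std::vector<float> T1((Rpad - 1) * C + Cpad, 1.0f);
  std::vector<float> T2(Rpad, 0.0f);

  for (int i24 = 0; i24 < ceilDiv(R, 4); ++i24)
    for (int i25 = 0; i25 < 4; ++i25) {
      const int i106 = i24 * 4 + i25; // reaches 11; only rows 0..9 are valid
      for (int i26 = 0; i26 < ceilDiv(C, 4); ++i26)
        for (int i27 = 0; i27 < 4; ++i27)
          // The column index reaches 11; element (i, 11) aliases (i + 1, 0).
          T2[i106] += T1[i106 * C + (i26 * 4 + i27)];
    }

  std::printf("T2[0] = %g (expected %d)\n", T2[0], C); // prints 12, not 11
  return 0;
}

With random inputs, that one stray element per row is consistent with the O(1) absolute error testValidate reports above.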
naoyam commented 2 years ago

Actually, it's not just reduction. This also fails validation due to the predicate of tv2 being eliminated:

TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeConcreteTensor({2, 3});
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
  auto tv3 = add(tv2, IrBuilder::create<Double>(1));
  fusion.addOutput(tv3);

  tv2->split(1, 2);
  tv1->split(1, 2);

  fusion.printMath();
  fusion.printKernel();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});

  auto ref = t0 + 3;

  testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}

It seems it's been broken ever since we changed how the index is linearized.

csarofeen commented 2 years ago

In your first example, it's the T1 = T0 one that needs to be predicated, right?

csarofeen commented 2 years ago

Had to revisit; I think I understand the issue now.

csarofeen commented 2 years ago

Okay, wow, this took me way too much time to understand. In the second example (since it's simpler), it's the combination of:

  1. The overlapping indices of T1 and T2, based on how we do indexing: we use the original stride at the root instead of back-propagating extents (should look up the issue where we discussed this)
  2. The fact that T2 is aliased to T1. So when we run:
    auto& T2 = T1;
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
      #pragma unroll
      for(nvfuser_index_t i26 = 0; i26 < (ceilDiv(3, 2)); ++i26) {
        #pragma unroll
        for(nvfuser_index_t i27 = 0; i27 < 2; ++i27) {
          T2[(i25 * 3) + ((i26 * 2) + i27)]
            = T1[(i25 * 3) + ((i26 * 2) + i27)]
            + (float) 1;
        }
      }
    }

    and we hit overlapping indices at [0, 1, 1] and [1, 0, 0], both of which map to linear offset 3, so:

         T2[(i25 * 3) + ((i26 * 2) + i27)]
          = T1[(i25 * 3) + ((i26 * 2) + i27)]
          + (float) 1;

    is run twice at that point, which wouldn't be an issue, except it's aliased memory, so it's effectively running an in-place computation twice (see the sketch below).
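
To see both pieces together, here is a standalone replay of that loop nest (my own sketch, not nvFuser code): with separate buffers, the duplicate visit to offset 3 just rewrites the same value, but with T2 aliased to T1 the in-place add is applied twice.

// Standalone sketch: replay the loop nest above with separate buffers and
// with T2 aliased to T1. Offset 3 is produced by both [i25, i26, i27] =
// [0, 1, 1] and [1, 0, 0].
#include <cstdio>
#include <vector>

void run(std::vector<float>& T1, std::vector<float>& T2) {
  for (int i25 = 0; i25 < 2; ++i25)
    for (int i26 = 0; i26 < 2; ++i26)   // ceilDiv(3, 2) iterations
      for (int i27 = 0; i27 < 2; ++i27)
        T2[(i25 * 3) + ((i26 * 2) + i27)]
            = T1[(i25 * 3) + ((i26 * 2) + i27)] + 1.0f;
}

int main() {
  // Buffers padded to 7 elements: [1, 1, 1] indexes one past the real
  // 2x3 tensor, another consequence of the overlapping indexing.
  std::vector<float> t1(7, 0.0f), t2(7, 0.0f);
  run(t1, t2); // separate buffers: the second visit to offset 3 rewrites
               // the same value, so t2[3] == 1
  std::vector<float> t(7, 0.0f);
  run(t, t);   // aliased, as in `auto& T2 = T1;`: the in-place add runs
               // twice, so t[3] == 2
  std::printf("separate: %g  aliased: %g\n", t2[3], t[3]);
  return 0;
}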

csarofeen commented 2 years ago

I think there are 3 options here; you went with the first one in https://github.com/csarofeen/pytorch/pull/1582:

  1. Disable predicate elimination when this type of pattern is detected
  2. Revert the way we do indexing and try to go back to a fully flat indexing pattern
  3. Disable aliasing when we would be able to remove the predicate

I think doing 1 in the short term is a reasonable approach to quickly patch the issue; however, I don't know all the trade-offs here. 2 was really annoying because of outer splits, which meant we would have needed even more predicates in those instances, so maybe it's a bad idea. 3 could be good, but then register pressure may increase.

naoyam commented 2 years ago

@csarofeen Yes, #1582 should disable eliminating required predicates. It's more conservative, but the impact should be minimal in practice. The problem only happens when an IterDomain is split but not parallelized; when it is parallelized, the domain becomes "zero merged in", so indexing computes the extent by multiplying the extents of the child domains. In any case, we haven't seen any actual errors since the indexing was changed.
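
To illustrate the distinction with the second example's numbers (my own sketch, not nvFuser code): stride-based indexing advances the outer index by the original extent 3, while the inner split loops cover ceilDiv(3, 2) * 2 = 4 positions, so consecutive outer iterations overlap; multiplying the child-domain extents instead advances it by 4, which leaves gaps but never collides.

// Standalone sketch: count index collisions under the two schemes.
#include <cstdio>
#include <set>

int main() {
  std::set<int> stride_based, extent_based;
  int dup_stride = 0, dup_extent = 0;
  for (int i25 = 0; i25 < 2; ++i25)
    for (int i26 = 0; i26 < 2; ++i26)   // ceilDiv(3, 2) iterations
      for (int i27 = 0; i27 < 2; ++i27) {
        // Original stride at the root: overlaps at offset 3.
        dup_stride += !stride_based.insert(i25 * 3 + i26 * 2 + i27).second;
        // Product of child-domain extents (2 * 2 = 4): injective.
        dup_extent += !extent_based.insert(i25 * 4 + i26 * 2 + i27).second;
      }
  std::printf("duplicates: stride-based = %d, extent-based = %d\n",
              dup_stride, dup_extent); // prints: 1, 0
}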