NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Reject fusions with CPU scalar outputs in `KernelExecutor` #3403

Closed · Priya2698 closed this 2 days ago

Priya2698 commented 1 week ago

Issue #2853 surfaced a case where adding a CPU scalar tensor to a scalar value generates a CUDA tensor. This is because nvFuser cannot generate CPU tensors. This PR adds a check in KernelExecutor::supported to identify any CPU scalar outputs and surface such unexpected cases.
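
For context, a minimal sketch of the kind of check this adds (the helper name hasCpuScalarOutput is illustrative and not the PR's actual diff; it assumes the usual nvFuser IR types):

bool hasCpuScalarOutput(Fusion* fusion) {
  // Walk every expression in the fusion and flag any TensorView output that
  // is marked as a CPU scalar; such a fusion cannot be lowered to a CUDA
  // kernel, so KernelExecutor::supported can reject it up front.
  for (Expr* expr : fusion->exprs()) {
    for (Val* out : expr->outputs()) {
      auto* tv = dynamic_cast<TensorView*>(out);
      if (tv != nullptr && tv->isCpuScalar()) {
        return true;
      }
    }
  }
  return false;
}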

Priya2698 commented 1 week ago

!test

Priya2698 commented 1 week ago

!test

Priya2698 commented 1 week ago

Since we could call setCpuScalar at any time, either during or after the Fusion has been defined, does it make sense to propagate that setting in a preseg pass?

Good point. We currently have 3 cases where we see CPU scalars:

  1. Fusion inputs - They will be set as CPU scalars at the time of definition.
  2. SdpaFwdOp - philox_seed and philox_offset are created as CPU scalar tensor outputs (https://github.com/NVIDIA/Fuser/blob/7b9271640b3b918b5cb65d9c68ea2c17d14319d2/csrc/ops/composite.cpp#L493-L496)
  3. SdpaBwdOp - Inputs philox_seed and philox_offset are either intermediate tensors generated from SdpaFwdOp or fusion inputs.

In all these cases, the relevant tensors are set as CPU scalars during definition, before compileKernel is called, which is where the KernelExecutor::supported check is exercised.
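
For case 1, a rough sketch of how an input gets flagged at definition time (makeSymbolicTensor and the no-argument setCpuScalar() call follow the snippets in this thread and are assumptions about the exact API):

Fusion fusion;
FusionGuard fg(&fusion);

// 0-dim input tensor, marked as a CPU scalar when the fusion is defined.
TensorView* tv0 = makeSymbolicTensor(0);
tv0->setCpuScalar();
fusion.addInput(tv0);

// Any expression consuming tv0 from here on sees tv0->isCpuScalar() == true.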

What scenario are you considering which would require a preseg pass propagation?

jacobhinkle commented 1 week ago

What scenario are you considering which would require a preseg pass propagation?

All three cases above would potentially be handled properly if we propagate this property in the definition ops. For example

tv0->setCpuScalar();
auto* tv1 = exp(tv0);
// what does tv1->isCpuScalar() return here?

Since we trust that this flag is set properly on fusion inputs, instead of relying on the ops to manage this at definition (and for this flag to survive all our rewriting passes), we could just propagate it at preseg like so:

void propagateCpuScalars(Fusion* fusion) {
  for (Expr* expr : StmtSort::getExprs(fusion)) {
    bool has_cpu_scalar_input = false;
    bool has_cuda_input = false;
    for (Val* inp : expr->inputs()) {
      if (auto* tv = dynamic_cast<TensorView*>(inp)) {
        if (tv->isCpuScalar()) {
          has_cpu_scalar_input = true;
        } else {
          has_cuda_input = true;
        }
      }
    }
    if (has_cpu_scalar_input && !has_cuda_input) {
      // Mark all TV outputs as CPU scalars
      for (Val* outp : expr->outputs()) {
        if (auto* tv = dynamic_cast<TensorView*>(outp)) {
          tv->setCpuScalar();
        }
      }
    }
  }
}
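
A sketch of the intended effect on the exp example above (assuming fusion is the enclosing Fusion*; illustrative only):

tv0->setCpuScalar();
auto* tv1 = exp(tv0);
// Before the pass nothing has forwarded the flag, so tv1->isCpuScalar()
// would still be false.
propagateCpuScalars(fusion);
// exp(tv0) has only a CPU scalar input, so the pass marks its output:
// tv1->isCpuScalar() now returns true.
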
Priya2698 commented 6 days ago

I see. This is a good alternative. So instead of going through the fusion expressions, as in the current approach, we can query this flag for the fusion->outputs(). It should work for SdpaFwdOp as well, since the relevant outputs are set as CPU scalars by the op. I do not have a preference on which implementation to go with -- from an overhead standpoint, both approaches should be fine.
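
A rough sketch of that alternative, checking only the fusion outputs rather than every expression (the helper name is made up for illustration):

bool hasCpuScalarFusionOutput(Fusion* fusion) {
  // Query the flag only on the terminal fusion outputs.
  for (Val* out : fusion->outputs()) {
    auto* tv = dynamic_cast<TensorView*>(out);
    if (tv != nullptr && tv->isCpuScalar()) {
      return true;
    }
  }
  return false;
}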

instead of relying on the ops to manage this at definition (and for this flag to survive all our rewriting passes)

If the passes rewrite this flag accidentally, setting the flags for outputs will also be error-prone. What do you mean by relying on the ops to manage this at definition -- is it that in the example above, the flag for tv1 will not be set?

Priya2698 commented 4 days ago

!test

Priya2698 commented 3 days ago

There is an issue with propagating this flag to tensor outputs. Consider:

        def fusion_func(fd: FusionDefinition):
            t0 = fd.from_pytorch(inputs[0])
            s0 = fd.from_pytorch(inputs[1])
            t1 = fd.ops.add(t0, s0)
            fd.add_output(t1)

Here, t0 will have an implicit broadcast:

T2_l_float[ bS1{1} ]
   = broadcast( T0_g_float[ ] )

The output of this broadcast will be inferred as a CPU scalar, but that will raise an error since we do not allow CPU non-scalar tensors. As a whole, this fusion produces a CUDA tensor output, so it is a valid fusion.

Priya2698 commented 3 days ago

!test

jacobhinkle commented 3 days ago

The output of this broadcast will be inferred as CPU scalar, but that will raise an error since we do not allow CPU non-scalar tensors. As a whole, this fusion produces a CUDA tensor output so is a valid fusion.

I see, so the failure occurs when tv->setCpuScalar() is called in the preseg pass in that case. Maybe the condition should be TensorDomain::noBroadcasts(getLogicalDomain()).empty() instead of nDims() == 0?

Priya2698 commented 3 days ago

The output of this broadcast will be inferred as CPU scalar, but that will raise an error since we do not allow CPU non-scalar tensors. As a whole, this fusion produces a CUDA tensor output so is a valid fusion.

I see, so the failure occurs when tv->setCpuScalar() is called in the preseg pass in that case. Maybe the condition should be TensorDomain::noBroadcasts(getLogicalDomain()).empty() instead of nDims() == 0?

Yes. If we choose to support returning CPU outputs through the expression evaluator, we would need the analysis to be within the preseg pass and the protocol to be modified, i.e. TensorDomain::noBroadcasts(getLogicalDomain()).empty() instead of nDims() == 0. I am not sure of the original motivation for allowing only CPU scalar tensors (and consequently, whether allowing broadcasted/expanded scalar tensors is an issue), though, or whether we can support CPU inputs more generally (currently, there isn't a strong use case AFAIK).
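
For reference, a sketch of the relaxed condition (assuming TensorDomain::noBroadcasts filters broadcast IterDomains out of a domain, as used above; illustrative only):

// Treat a tensor as an effective scalar if its logical domain contains no
// non-broadcast IterDomains, instead of requiring nDims() == 0.
bool isEffectivelyScalarShaped(TensorView* tv) {
  return TensorDomain::noBroadcasts(tv->getLogicalDomain()).empty();
}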

I can merge this PR as-is and open an issue for the above, wdyt?

Priya2698 commented 3 days ago

!test

jacobhinkle commented 2 days ago

I can merge this PR as-is and open an issue for the above, wdyt?

Yes, I think if this issue is not currently causing any breakage then we should do that to get you unblocked.