NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Optimize fmax with NAN #319

Open naoyam opened 1 year ago

naoyam commented 1 year ago

FP max reductions would typically look like:

  for(nvfuser_index_t i154 = 0; i154 < 8; ++i154) {
    int i299;
    i299 = 4 * i154;
#pragma unroll
      for(nvfuser_index_t i156 = 0; i156 < 4; ++i156) {
        T29[0] = fmax(
            T29[0],
            T24[(i299 + i156)]);
      }
  }

Here, fmax is not just fmaxf: it also incurs two more comparisons to handle NaN arguments: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#LL102C1-L111C2

This could be translated as:

  bool is_nan = false;
#pragma unroll
  for(nvfuser_index_t i154 = 0; i154 < 8; ++i154) {
    int i299;
    i299 = 4 * i154;
#pragma unroll
      for(nvfuser_index_t i156 = 0; i156 < 4; ++i156) {
#if 0
        T29[0] = fmax(
            T29[0],
            T24[(i299 + i156)]);
#else
        T29[0] = T29[0] > T24[(i299 + i156)] ? T29[0] : T24[(i299 + i156)];
        is_nan = is_nan || isnan(T24[(i299 + i156)]);
#endif
      }
  }
  if (is_nan) {
    T29[0] = NAN;
  }

In the case of cross entropy loss (#278), I observed 20% speedup on A100.

I think this translation could be done automatically as part of lowering. See the translation for Welford vectorization.

jacobhinkle commented 1 year ago

This is a good idea to speed up every call to max (and min). Furthermore, in the particular case of softmax, I think the NaN checks could be avoided entirely, since even if the max were non-NaN the resulting softmax still will be NaN due to the later sum, so propagation happens automatically. In that case we'd prefer a true fmaxf-style max computation with no NaN check (like torch.fmax instead of torch.maximum). We could provide two different binary ops for each of min and max by adding a propagate_nan=true option to the max and min functions, which we would set to false for softmax and log_softmax.

jacobhinkle commented 1 year ago

We could even determine automatically, without user input, whether NaN checks can be skipped, but it involves a bit of a complicated traversal. We would visit each min or max op or reduction and mark the result as unchecked and the input as propagated, then move downstream from further uses of the input and from the min/max output, propagating the tags. When a pointwise binary or ternary op other than min or max has an unchecked input and a propagated input, it resolves the might-be-unchecked flag and the result becomes properly propagated. We stop if we reach an output or another reduction; if any output is not marked as propagated, it means we can't skip NaN checks in the original op.

jacobhinkle commented 1 year ago

For the propagation analysis, consider this simple example with no reductions (assuming we had the non-reduction maximum/minimum ops that do NaN checks):

auto tv2 = maximum(tv0, IrBuilder::create<Float>(0.0));
auto tv3 = add(tv0, tv1);
auto tv4 = minimum(tv3, IrBuilder::create<Float>(10.0));
auto tv5 = sub(tv4, tv2);
fusion->addOutput(tv5);

Currently, we would do NaN checks when computing both tv2 and tv4, but we could eliminate the check on tv2, since unpropagated NaNs would be resolved in tv5. The analysis is similar but a little more complicated when there are reductions and broadcasts, as in softmax.

jacobhinkle commented 1 year ago

For the immediate task at hand, I believe we can remove one NaN check more easily by removing this branch: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#L106-L107. If a == a, then a > b will be false whenever b != b, so the else branch subsumes the b != b branch. Similarly, for fmin we can remove the a != a branch. Doing this removes one NaN check per reduction element, so it should be equivalent in speed to the code in the description of this issue, but doesn't necessitate any codegen changes. As mentioned in my other comments, we can probably remove all the NaN checks, but I will leave that for a separate PR.

naoyam commented 1 year ago

even if the max were non-NaN the resulting softmax still will be NaN due to the later sum, so propagation happens automatically.

Why? Can you please elaborate?

jacobhinkle commented 1 year ago

Sure. For log_softmax, we do something like the following:

auto a_max = max(a, {1});
auto a_sub = sub(a, broadcast(a_max, {false, true}));
auto a_exp = exp(a_sub);
auto a_lse = log(sum(a_exp, {1}));
auto log_softmax = sub(a_sub, broadcast(a_lse, {false, true}));

Imagine there is a single NaN in some row of a. Ordinarily, this would mean a_max is NaN, which would cascade down so that the whole row of a_sub is NaN, and ultimately the whole row of log_softmax is NaN.

Now imagine we don't check NaNs for max, so that a_max is not NaN in that row where a contains a NaN. Then a_sub contains just that single NaN, as does a_exp, since sub and exp propagate NaNs. But to compute the logsumexp a_lse, we sum a_exp. That sum propagates NaNs, so a_lse, and hence log_softmax, is NaN.

naoyam commented 1 year ago

Oh, I see, very interesting. I have never heard of such a data-flow analysis for NaNs.

In terms of actual benefits, the transformation I showed above may be just enough and is easier to implement. I saw it achieve almost the same performance as fmaxf, likely because the root cause of the overhead is the NaN-check branch, which is almost completely eliminated.

jacobhinkle commented 1 year ago

Interesting. Have you tried it without any NaN checking? I am not sure how fmaxf works internally, but it seems like it would need to branch, since it guarantees to return the non-NaN value when only one input is NaN. On the other hand, without NaN checks, i.e. just a > b ? a : b, the result might be NaN or not, and we don't care!

naoyam commented 1 year ago

Sorry, I meant I compared the performance with just a > b ? a : b.

jacobhinkle commented 1 year ago

With #329, each call to fmax will also only have one isnan check, but it comes before a > b instead of after. So shouldn't the #if 0 block now give roughly the same performance as the #else code above?

naoyam commented 1 year ago

I actually tried that as well, and it was about midway between the two cases, with about a 10% improvement.

jacobhinkle commented 1 year ago

Interesting! Maybe we should re-open and I'll get to work on the original approach.