naoyam opened this issue 1 year ago
This is a good idea to speed up every call to `max` (and `min`). Furthermore, in the particular case of softmax, I think the nan checks could be avoided entirely: even if the max were non-nan, the resulting softmax would still be nan due to the later sum, so propagation happens automatically. In that case we'd prefer a true `fmaxf`-style max computation with no nan check (like `torch.fmax` instead of `torch.maximum`). We could provide two different binary ops for each of min and max and add a `propagate_nan=true` option to the `max` and `min` functions, which we would set to `false` for `softmax` and `log_softmax`.
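As a rough illustration, here is a minimal sketch of what a `propagate_nan` switch could look like for a serial max reduction. The names (`reduceMax`, `PropagateNan`) are hypothetical, not the actual Fuser API:

```cpp
#include <cstddef>

// Hypothetical sketch, not the Fuser API: a max reduction with a
// compile-time propagate_nan switch. With PropagateNan = true it behaves
// like torch.maximum (any nan input makes the result nan); with false it
// is a bare fmaxf-style max that may silently drop a nan.
template <bool PropagateNan>
double reduceMax(const double* data, std::size_t n) {
  double m = data[0];
  for (std::size_t i = 1; i < n; ++i) {
    double x = data[i];
    if (PropagateNan) {
      if (m != m) {
        return m;  // running max is already nan: propagate it
      }
      if (x != x) {
        return x;  // new element is nan: propagate it
      }
    }
    m = m > x ? m : x;  // unchecked comparison; false for any nan operand
  }
  return m;
}
```

`reduceMax<false>` is the variant softmax could use, since the later sum re-propagates any nan the max drops.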
We could even determine whether nan checks can be skipped automatically, without user input, but it involves a somewhat complicated traversal. We would visit each min or max op (or reduction), mark its result as `unchecked` and its input as `propagated`, then move downstream from further uses of that input and from the min/max output, propagating the tags. When a pointwise binary or ternary op other than min or max has an `unchecked` input and a `propagated` input, it resolves the might-be-unchecked flag and its output becomes properly `propagated`. If we reach a fusion output or another reduction we stop, and if any output is not marked as `propagated`, it means we can't skip the nan checks in the original op.
For the propagation analysis, consider this simple example with no reductions (assuming we had non-reduction `maximum`/`minimum` ops that do nan checks):
```cpp
auto tv2 = maximum(tv0, IrBuilder::create<Float>(0.0));
auto tv3 = add(tv0, tv1);
auto tv4 = minimum(tv3, IrBuilder::create<Float>(10.0));
auto tv5 = sub(tv4, tv2);
fusion->addOutput(tv5);
```
Currently, we would do nan checks when computing both `tv2` and `tv4`, but we could eliminate the check on `tv2`, since unpropagated nans would be resolved in `tv5`. The analysis is similar but a little more complicated when there are reductions and broadcasts, as in softmax.
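To make that traversal concrete, here is a toy sketch of the tagging analysis on a straight-line op list. All names here (`Tag`, `Op`, `canSkipNanCheck`) are made up for illustration, not Fuser code; the sketch also ignores reductions and broadcasts, and conservatively assumes every other min/max op keeps its nan check:

```cpp
#include <map>
#include <string>
#include <vector>

// Toy sketch of the tagging traversal described above (hypothetical
// names, not Fuser code). Assumes `ops` is in topological order.
enum class Tag { kNone, kUnchecked, kPropagated };

struct Op {
  std::string name;                 // e.g. "maximum", "add"
  std::vector<std::string> inputs;  // tensor inputs (scalars omitted)
  std::string output;
};

// Can the nan check be skipped on `candidate` (a min/max op)?
bool canSkipNanCheck(
    const std::vector<Op>& ops,
    const Op& candidate,
    const std::vector<std::string>& fusionOutputs) {
  std::map<std::string, Tag> tag;
  tag[candidate.output] = Tag::kUnchecked;      // result might hide a nan
  tag[candidate.inputs[0]] = Tag::kPropagated;  // input still carries it
  for (const Op& op : ops) {
    if (op.output == candidate.output) {
      continue;  // skip the candidate itself
    }
    bool unchecked = false;
    bool propagated = false;
    for (const std::string& in : op.inputs) {
      unchecked |= (tag[in] == Tag::kUnchecked);
      propagated |= (tag[in] == Tag::kPropagated);
    }
    // A propagated input keeps carrying the nan through any op that
    // propagates nans; this also resolves a might-be-unchecked input.
    if (propagated) {
      tag[op.output] = Tag::kPropagated;
    } else if (unchecked) {
      tag[op.output] = Tag::kUnchecked;
    }
  }
  // Safe only if every fusion output is guaranteed to still see the nan.
  for (const std::string& out : fusionOutputs) {
    if (tag[out] != Tag::kPropagated) {
      return false;
    }
  }
  return true;
}
```

On the `tv0`…`tv5` example above, this marks `tv2`'s check as skippable (the nan reaches `tv5` through `tv3` and `tv4`) but conservatively keeps the check on `tv4`.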
For the immediate task at hand, I believe we can remove one nan check more easily by removing this branch: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#L106-L107. If `a == a`, then `a > b` will be false when `b != b`, so the else branch subsumes the `b != b` branch. Similarly, for `fmin` we can remove the `a != a` branch. Doing this removes one nan check per reduction element, so it should be equivalent in speed to the code in the description of this issue, but doesn't necessitate any codegen changes. As mentioned in my other comments, we can probably remove all the nan checks, but I will leave that for a separate PR.
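To see the equivalence, here is a standalone version of the two variants (a sketch mirroring the structure of the helpers.cu `fmax`, not the exact code):

```cpp
// Sketch of the simplification above (mirrors the structure of the
// helpers.cu fmax, not the exact code). The b != b branch is redundant:
// when a == a and b != b, `a > b` is false, so the final branch already
// returns b, the nan.
double fmaxChecked(double a, double b) {
  if (a != a) {
    return a;  // a is nan: propagate it
  } else if (b != b) {
    return b;  // redundant: the branch below would return b anyway
  } else {
    return a > b ? a : b;
  }
}

double fmaxSimplified(double a, double b) {
  if (a != a) {
    return a;  // a is nan: propagate it
  }
  return a > b ? a : b;  // if b is nan, a > b is false, so b is returned
}
```

The two functions agree for every combination of nan and non-nan inputs, with one comparison fewer per call in the simplified version.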
> even if the max were non-nan the resulting softmax still will be nan due to the later sum, so propagation happens automatically.

Why? Can you please elaborate?
Sure. For `log_softmax`, we do something like the following:
```cpp
auto a_max = max(a, {1});
auto a_sub = sub(a, broadcast(a_max, {false, true}));
auto a_exp = exp(a_sub);
auto a_lse = log(sum(a_exp, {1}));
auto log_softmax = sub(a_sub, broadcast(a_lse, {false, true}));
```
Imagine there is a single nan in some row of `a`. Ordinarily, this would mean `a_max` is nan, which would cascade down so that the whole row of `a_sub` is nan, and ultimately the whole row of `log_softmax` is nan.
Now imagine we don't check nans for max, so that `a_max` is not nan in the row where `a` contains a nan. Then `a_sub` contains just that single nan, as does `a_exp`, since `sub` and `exp` propagate nans. But to compute the logsumexp `a_lse`, we sum `a_exp`. That sum propagates nans, so `a_lse`, and hence the whole row of `log_softmax`, is nan.
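The argument above can be checked numerically with a plain C++ sketch of the per-row computation, using an unchecked `a > b ? a : b` max (hypothetical helper, not Fuser code):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Numeric sketch of the argument above: compute log_softmax of a row
// using an unchecked max (a > b ? a : b, no nan check). Even when the
// max silently drops a nan, the sum inside logsumexp re-propagates it,
// so the whole output row is still nan. Hypothetical helper, not Fuser.
std::vector<double> logSoftmaxUncheckedMax(const std::vector<double>& a) {
  double m = a[0];
  for (double x : a) {
    m = m > x ? m : x;  // unchecked max: may return non-nan on nan input
  }
  double sum = 0.0;
  std::vector<double> a_sub(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    a_sub[i] = a[i] - m;        // the single nan survives here
    sum += std::exp(a_sub[i]);  // and poisons the whole sum
  }
  double lse = std::log(sum);  // nan if the row contained a nan
  for (std::size_t i = 0; i < a.size(); ++i) {
    a_sub[i] -= lse;  // so every output element becomes nan
  }
  return a_sub;
}
```

For the row `{1.0, nan, 3.0}` the unchecked max yields `3.0` (non-nan), yet every element of the result is still nan, which is exactly why the check on the max can be skipped.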
Oh, I see, very interesting. I have never heard of such a data-flow analysis for nans.
In terms of actual benefits, the transformation I showed above may be just enough and is easier to implement. I saw it get almost the same performance as `fmaxf`, likely because the root cause of the overhead is the nan-check branch, which is almost completely eliminated.
Interesting. Have you tried it without any nan checking? I am not sure how `fmaxf` works internally, but it seems like it would need to branch, since it guarantees to return the non-nan value when only one input is nan. On the other hand, without nan checks, i.e. just `a > b ? a : b`, the result might or might not be nan, and we don't care!
Sorry, I meant I compared the performance with just `a > b ? a : b`.
With #329, each call to `fmax` will also have only one `isnan` check, but it comes before `a > b` instead of after. So shouldn't the `#if 0` block now give roughly the same performance as the `#else` code above?
I actually tried that as well, and it was about midway between the two cases, with about a 10% improvement.
Interesting! Maybe we should re-open and I'll get to work on the original approach.
FP max reductions would typically look like:
Here, `fmax` is not just `fmaxf`; it also incurs two more comparisons when the arguments are nan: https://github.com/NVIDIA/Fuser/blob/main/runtime/helpers.cu#LL102C1-L111C2

This could be translated as:
In the case of cross entropy loss (#278), I observed 20% speedup on A100.
I think this translation could be done automatically as part of lowering; see the translation for Welford vectorization.