jacobhinkle opened 9 months ago
We shouldn't have to cast to float before negation. I think the right fix should happen around https://github.com/NVIDIA/Fuser/blob/f681b2fc99b4856966b6f7e116c1a19e22688e39/csrc/codegen.cpp#L711. We should probably use `__hneg` to negate a bfloat16.
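For illustration, a minimal sketch of the two possible lowerings of a bfloat16 negation: the current round trip through float, and a direct `__hneg`. The helper names here are made up; only the `cuda_bf16.h` intrinsics are real, and the bfloat16 arithmetic intrinsics generally assume sm_80 or a recent CUDA toolkit.

```cpp
#include <cuda_bf16.h>

// Sketch only: two ways a bfloat16 negation could be emitted.
// Function names are hypothetical; the intrinsics come from cuda_bf16.h.
__device__ __nv_bfloat16 neg_via_float(__nv_bfloat16 x) {
  // Current pattern: up-cast to float, negate, down-cast.
  return __float2bfloat16(-__bfloat162float(x));
}

__device__ __nv_bfloat16 neg_direct(__nv_bfloat16 x) {
  // Proposed pattern: negate directly in bfloat16.
  return __hneg(x);
}
```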
Good point: `neg` and `abs` are probably the two exceptions to the rule when it comes to needing to cast to float to perform arithmetic. This particular bug would still be encountered, though, if we replaced `neg` with `exp` or something like that.
That's right. For any operation that's not in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__ARITHMETIC.html#group__CUDA__MATH____BFLOAT16__ARITHMETIC_1g93a96c42cf62bbf4d0a9b4b2c268b192, we'll have to make a roundtrip to float.
Question though: who should insert these casts? The user, the API, or an nvFuser pass? Wdyt? I'm leaning towards the user or the API, just to make precision explicit. (It may just be my PTSD after dealing with too many precision/accuracy issues in my prior work.)
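For concreteness, the round trip for an op like `exp`, which is not in the bfloat16 arithmetic set, would look roughly like this (sketch only; the function name is made up):

```cpp
#include <cuda_bf16.h>

// Sketch: an op outside the bfloat16 arithmetic intrinsics is computed by
// up-casting to float, evaluating in fp32, and down-casting the result.
__device__ __nv_bfloat16 exp_via_float(__nv_bfloat16 x) {
  return __float2bfloat16(expf(__bfloat162float(x)));
}
```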
Generally we err on the side of accuracy: within a fusion we promote to float to do arithmetic, even if the user specified all `Half` tensors.
> That's right. For any operation that's not in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__ARITHMETIC.html#group__CUDA__MATH____BFLOAT16__ARITHMETIC_1g93a96c42cf62bbf4d0a9b4b2c268b192, we'll have to make a roundtrip to float.
I mentioned `neg` and `abs` because they are bitwise equal to their roundtrip-through-float equivalents. The ops you linked do reduced-precision arithmetic and cast the result, so there is a loss of precision w.r.t. float if you chain several of them together. I think that's why we don't support them (also, I guess we don't have much demand for them, but if we had a compute-bound kernel that needed it we could add it).
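A small example of that precision loss from chaining (assuming `__hadd` and the conversions from `cuda_bf16.h`; treat it as device code on toolkits that don't expose bf16 arithmetic on the host):

```cpp
#include <cuda_bf16.h>

// Sketch: bf16 keeps 8 significant bits, so 1.0 + 2^-8 rounds back to 1.0.
// Chaining the adds in bf16 loses the small terms, while summing in fp32 and
// casting once at the end does not.
__device__ void chaining_example() {
  __nv_bfloat16 a = __float2bfloat16(1.0f);
  __nv_bfloat16 b = __float2bfloat16(1.0f / 256.0f);  // 2^-8

  __nv_bfloat16 chained = __hadd(__hadd(a, b), b);    // stays 1.0
  __nv_bfloat16 once = __float2bfloat16(
      __bfloat162float(a) + __bfloat162float(b) + __bfloat162float(b));
  // 'once' is 1 + 2^-7, 'chained' is 1.0.
  (void)chained; (void)once;
}
```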
Got it. That's a valid contract as well -- nvFuser always computes in fp32 unless it's certain that using a lower precision preserves the result bitwise. E.g. `cast(neg(bfloat))` and `neg(cast(bfloat))` are bit-exact, and `view(permute(bfloat))` and `cast(view(permute(cast(bfloat))))` are bit-exact.
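The bit-exactness claim for negation can be checked exhaustively over all bfloat16 bit patterns. A rough host-side sketch (assumes a CUDA toolkit where `__hneg` and the bf16/float conversions are usable from host code; NaN payloads are skipped since intrinsics may canonicalize them):

```cpp
#include <cuda_bf16.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch: compare neg-in-bf16 against cast-to-float, negate, cast-back for
// every bfloat16 bit pattern except NaNs.
int main() {
  for (uint32_t bits = 0; bits < 0x10000; ++bits) {
    uint16_t raw = static_cast<uint16_t>(bits);
    bool is_nan = ((raw >> 7) & 0xFF) == 0xFF && (raw & 0x7F) != 0;
    if (is_nan) continue;

    __nv_bfloat16 x;
    std::memcpy(&x, &raw, sizeof(raw));

    __nv_bfloat16 a = __hneg(x);                               // neg in bf16
    __nv_bfloat16 b = __float2bfloat16(-__bfloat162float(x));  // float roundtrip

    uint16_t ra, rb;
    std::memcpy(&ra, &a, sizeof(ra));
    std::memcpy(&rb, &b, sizeof(rb));
    if (ra != rb) std::printf("mismatch at 0x%04x\n", raw);
  }
  return 0;
}
```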
This test fails to compile:
Generated kernel and error:
This is the first segment which should just be negating the bfloat inputs. It looks like casts are not inserted properly.
In the test, we are placing a `segment_set` to prevent "forwarding" the input through unary ops. In this case the overloaded term "forwarded" means the segmenter is pretending the output of a chain of unary ops from an input is itself the input. It then replays all the unary ops that were forwarded in segments that use that forwarded input. This is usually a good thing, but in this case the replaying interferes with the `slice` op, which requires its inputs to be fusion inputs. So here we insert a `segment_set`, ensuring that the slice is a segment input.

Related to #1553. This was exposed when testing #1556.
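For reference, a fusion in that shape might be built roughly like this. This is an illustrative sketch only, not the test attached to this issue; the header paths, `makeSymbolicTensor` test helper, and exact `slice`/`segment_set` signatures are assumptions based on nvFuser's C++ API and may differ.

```cpp
#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

void buildFusionSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // bfloat16 input followed by a unary op.
  TensorView* tv0 = makeSymbolicTensor(1, DataType::BFloat16);
  fusion.addInput(tv0);
  TensorView* tv1 = neg(tv0);

  // segment_set forces a segment boundary so the segmenter does not "forward"
  // the unary chain; slice then sees a genuine segment input.
  TensorView* tv2 = segment_set(tv1);
  TensorView* tv3 = slice(tv2, /*starts=*/{0}, /*stops=*/{8});
  fusion.addOutput(tv3);
}
```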