NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Compilation failure using `segment_set` with half precision inputs #1557

Open jacobhinkle opened 9 months ago

jacobhinkle commented 9 months ago

This test fails to compile:

    def test_simple_slice_fusion_bfloat16(self):
        inputs = [torch.randn((10,), dtype=torch.bfloat16, device="cuda:0")]

        def fusion_func(fd: FusionDefinition) -> None:
            T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16)
            T1 = fd.ops.neg(T0)
            T2 = fd.ops.segment_set(T1)  # force a segment boundary so slice sees a segment input
            T3 = fd.ops.slice(T2, start_indices=[0], end_indices=[5])
            fd.add_output(T3)

        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

        ref = -inputs[0][:5]

        torch.testing.assert_close(nvf_out[0], ref)

Generated kernel and error:

    __global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<__bfloat, 1, 1> T0, Tensor<__bfloat, 1, 1> T2) {
      nvfuser_index_t i0;
      i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
      if ((i0 < T0.logical_size[0LL])) {
        __bfloat T4[1];
        T4[0] = 0;
        T4[0]
           = T0[i0];
        __bfloat T1[1];
        T1[0]
           = -T4[0];
        __bfloat T5[1];
        T5[0]
           = T1[0];
        T2[i0]
           = T5[0];
      }
    }

    CUDA NVRTC compile error: __tmp_kernel_pointwise_f0_c1_r0_g0.cu(9455): error: no operator "-" matches these operands
                operand types are: - <unnamed>::__bfloat
             = -T4[0];
               ^

This is the first segment, which should just be negating the bfloat16 input. It looks like casts are not being inserted properly.
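For comparison, a correct lowering would roundtrip the negation through float, roughly like this (a sketch assuming the `__bfloat2float`/`__float2bfloat` cast helpers from nvFuser's bf16 runtime support; the exact emitted form is an assumption):

    __bfloat T1[1];
    T1[0]
       = __float2bfloat(-__bfloat2float(T4[0]));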

In the test, we place a segment_set to prevent "forwarding" of the input through unary ops. Here the overloaded term "forwarded" means the segmenter pretends that the output of a chain of unary ops starting at an input is itself the input; it then replays those forwarded unary ops in every segment that uses the forwarded value. This is usually a good thing, but in this case the replaying interferes with the slice op, which requires its inputs to be fusion inputs. Inserting the segment_set ensures that the input to slice is a segment input.

Related to #1553. This was exposed when testing #1556.

wujingyue commented 7 months ago

We shouldn't have to cast to float before negation. I think the right fix should happen around https://github.com/NVIDIA/Fuser/blob/f681b2fc99b4856966b6f7e116c1a19e22688e39/csrc/codegen.cpp#L711. We probably should use __hneg to negate a bfloat16.
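For reference, `cuda_bf16.h` exposes native bfloat16 negation as `__hneg`. A minimal sketch of the operation (using the standard `__nv_bfloat16` type rather than nvFuser's own `__bfloat`, and requiring bf16 arithmetic support, i.e. sm_80+):

    #include <cuda_bf16.h>

    // Negate a bfloat16 directly, with no roundtrip through float.
    __device__ __nv_bfloat16 neg_bf16(__nv_bfloat16 x) {
      return __hneg(x);
    }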

jacobhinkle commented 7 months ago

Good point: neg and abs are probably the two exceptions to the rule that arithmetic requires a cast to float. This particular bug would still be hit, though, if we replaced neg with exp or something like that.

wujingyue commented 7 months ago

That's right. For any operation that's not in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__ARITHMETIC.html#group__CUDA__MATH____BFLOAT16__ARITHMETIC_1g93a96c42cf62bbf4d0a9b4b2c268b192, we'll have to make a roundtrip to float.
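At the CUDA level, that roundtrip looks roughly like this. A sketch for exp, which is not on the bfloat16 arithmetic list; `exp_bf16` is an illustrative name, not an nvFuser function:

    #include <cuda_bf16.h>
    #include <cmath>

    // Ops without a bfloat16 arithmetic intrinsic: cast up to float,
    // compute in fp32, then cast back down to bfloat16.
    __device__ __nv_bfloat16 exp_bf16(__nv_bfloat16 x) {
      return __float2bfloat16(expf(__bfloat162float(x)));
    }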

Question though: who should insert these casts? The user, the API, or an nvFuser pass? Wdyt? I'm leaning towards the user or the API, just to make precision explicit. (It may just be my PTSD from dealing with too many precision/accuracy issues in my prior work.)

jacobhinkle commented 7 months ago

Generally we err on the side of accuracy: within a fusion we promote to float to do arithmetic, even if the user specified all Half tensors.

> That's right. For any operation that's not in https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__ARITHMETIC.html#group__CUDA__MATH____BFLOAT16__ARITHMETIC_1g93a96c42cf62bbf4d0a9b4b2c268b192, we'll have to make a roundtrip to float.

I mentioned neg and abs because they are bitwise equal to their roundtrip-through-float equivalents. The ops you linked instead do reduced-precision arithmetic and cast the result, so there is a loss of precision relative to float if you chain several of them together. I think that's why we don't support them (we also don't have much demand for them, but if we had a compute-bound kernel that needed one, we could add it).

wujingyue commented 7 months ago

Got it. That's a valid contract as well: nvFuser always computes in fp32 unless it's certain that using a lower precision preserves the result bitwise. E.g. cast(neg(bfloat)) and neg(cast(bfloat)) are bit-exact, and view(permute(bfloat)) and cast(view(permute(cast(bfloat)))) are bit-exact.
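The first claim is easy to sanity-check by sweeping all 2^16 bfloat16 bit patterns and comparing native negation against the float roundtrip. A standalone sketch, not nvFuser code; NaN encodings are skipped since conversions may canonicalize their payloads, and the bf16 intrinsics need sm_80+:

    #include <cuda_bf16.h>
    #include <cstdio>

    // Compare __hneg(x) with __float2bfloat16(-__bfloat162float(x)) for
    // every non-NaN bfloat16 encoding; count any bitwise mismatches.
    __global__ void check_neg_bitexact(int* mismatches) {
      unsigned bits = blockIdx.x * blockDim.x + threadIdx.x;
      bool is_nan = (bits & 0x7F80u) == 0x7F80u && (bits & 0x007Fu) != 0u;
      if (bits >= (1u << 16) || is_nan) {
        return;
      }
      __nv_bfloat16 x = __ushort_as_bfloat16((unsigned short)bits);
      unsigned short direct = __bfloat16_as_ushort(__hneg(x));
      unsigned short roundtrip =
          __bfloat16_as_ushort(__float2bfloat16(-__bfloat162float(x)));
      if (direct != roundtrip) {
        atomicAdd(mismatches, 1);
      }
    }

    int main() {
      int* mismatches;
      cudaMallocManaged(&mismatches, sizeof(int));
      *mismatches = 0;
      check_neg_bitexact<<<256, 256>>>(mismatches);  // 256 * 256 = 2^16 threads
      cudaDeviceSynchronize();
      printf("mismatches: %d\n", *mismatches);  // expect 0
      cudaFree(mismatches);
      return 0;
    }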