csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

codegen missing fp16 math support #2088

Open jjsjann123 opened 2 years ago

jjsjann123 commented 2 years ago

🐛 Describe the bug

For the issues I'm seeing in our benchmark, I think we can work around it by making the batch_norm promotion explicit in the graph, so I'll try to work around it in nvprim/primtorch.

Note: fp16 math can be requested explicitly through our python API. We currently have no codegen support for it, and that is going to be a real problem.

Repro script and error message below.

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T2 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T3 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T4 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    S5 = fd.define_constant(0.100000)
    S6 = fd.define_constant(1.00000e-05)
    #T7, T8, T9 = fd.ops.batch_norm(T0, T1, T2, T3, T4, S5, S6)  # note: @kevinstephano python definition gives me an invalid script, which doesn't contain the scalar input to the op here.
    T7, T8, T9 = fd.ops.batch_norm(T0, T1, T2, T3, T4, True, S5, S6, False)
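    # note: T0 is Half, so without explicit promotion the generated batch_norm/welford math runs in __half, which triggers the NVRTC errors below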
    T10 = fd.ops.cast(T7, dtype=DataType.Float)
    T11 = fd.ops.neg(T10)
    T12 = fd.ops.exp(T11)
    S13 = fd.define_constant(1.00000)
    T14 = fd.ops.add(S13, T12)
    S15 = fd.define_constant(1.00000)
    T16 = fd.ops.div(S15, T14)
    T17 = fd.ops.mul(T10, T16)
    T18 = fd.ops.cast(T17, dtype=DataType.Half)
    fd.add_output(T7)
    fd.add_output(T8)
    fd.add_output(T9)
    fd.add_output(T18)

inputs = [
    torch.randn(2, 3, 8, 8, device='cuda').half(),
    torch.randn(3, device='cuda'),
    torch.randn(3, device='cuda'),
    torch.randn(3, device='cuda'),
    torch.randn(3, device='cuda'),
]

fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion(fd)

for _ in range(5) :
    fs.execute(inputs)

CUDA NVRTC compile error: __tmp_kernel1.cu(4488): error: no operator "/" matches these operands
            operand types are: CudaCodeGen::__half / CudaCodeGen::__half
          detected during instantiation of "void CudaCodeGen::welfordCombine(T &, T &, TN &, T, T, TN) [with T=CudaCodeGen::__half, TN=int]"
(7984): here

__tmp_kernel1.cu(4489): error: no operator "-" matches these operands
            operand types are: const CudaCodeGen::__half - CudaCodeGen::__half
          detected during instantiation of "void CudaCodeGen::welfordCombine(T &, T &, TN &, T, T, TN) [with T=CudaCodeGen::__half, TN=int]"
(7984): here

__tmp_kernel1.cu(4490): error: no operator "*" matches these operands
            operand types are: CudaCodeGen::__half * CudaCodeGen::__half
          detected during instantiation of "void CudaCodeGen::welfordCombine(T &, T &, TN &, T, T, TN) [with T=CudaCodeGen::__half, TN=int]"
(7984): here

__tmp_kernel1.cu(4491): error: no operator "*" matches these operands
            operand types are: CudaCodeGen::__half * CudaCodeGen::__half
          detected during instantiation of "void CudaCodeGen::welfordCombine(T &, T &, TN &, T, T, TN) [with T=CudaCodeGen::__half, TN=int]"
(7984): here

Versions

torchbenchPerf

commit 124fba69b1101cece6d6d0b781c32a11e481dbc7 (HEAD, csarofeen/torchbenchPerf)
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Mon Oct 17 16:34:38 2022 +0300

    Revert "Intercept aten._reshape_alias for nvFuser"

    This reverts commit eb24dbad03791483fe7a9f7906186dceec6aa6a2.

jjsjann123 commented 2 years ago

cc'ing @kevinstephano on the python_definition issue, which spits out a fusion script that's missing the const python scalar arguments to batch_norm.

kevinstephano commented 2 years ago

Are we sure the math issue isn't an indicator that something else is broken, given that we consciously decided not to implement FP16 math ops in the runtime because they should be upcast?

kevinstephano commented 2 years ago

Sounds like Jie is going to look into this in the frontend.

jjsjann123 commented 2 years ago

Are we sure the math issue isn't an indicator that something else is broken, given that we consciously decided not to implement FP16 math ops in the runtime because they should be upcast?

Sort of yes and no.

The lack of fp16 support is real: when a user explicitly defines something like the script above, codegen is going to scream and fail.

The patch I have in #2104 added explicit type promotion for nvprims.native_batch_norm, so at least when running through the dynamo stack we won't run into these problems anymore.
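For reference, a minimal sketch of what explicit promotion around batch_norm could look like, written against the same FusionDefinition API as the repro above (illustration only, with made-up tensor names; the actual fix in #2104 lives in the nvprim/primtorch layer):

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_promoted(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Half)
    T1 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T2 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T3 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    T4 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    S5 = fd.define_constant(0.100000)
    S6 = fd.define_constant(1.00000e-05)
    # promote the Half input before batch_norm so the generated welford/normalization math is fp32
    T0f = fd.ops.cast(T0, dtype=DataType.Float)
    T7, T8, T9 = fd.ops.batch_norm(T0f, T1, T2, T3, T4, True, S5, S6, False)
    # demote the normalized output back to Half to preserve the original output dtype
    T7h = fd.ops.cast(T7, dtype=DataType.Half)
    fd.add_output(T7h)
    fd.add_output(T8)
    fd.add_output(T9)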

csarofeen commented 2 years ago

I thought codegen auto-promotes operations to FP32 when given half inputs. This might be an issue with segmentation.

csarofeen commented 2 years ago

No, I think I'm wrong and we auto demote intermediates to FP16 in segmentation.

jjsjann123 commented 2 years ago

No, I think I'm wrong and we auto demote intermediates to FP16 in segmentation.

Sorry I missed your comment @csarofeen, but that only happens when an intermediate in the original fusion definition is marked as fp16. Anyway, there are two problems we have with fp16 support:

  1. We need dtype correctness on outputs (linking #2115). We'd be better off having something done in the integration as a catch-all safety net.
  2. fp16 math is likely needed as well: a) Ruberry made an argument for giving users explicit control of the math type; whether we prioritize it or not is a different question. b) IIUC, fp16 math could be perf critical in the longer term.
csarofeen commented 2 years ago

Just to break it down simply for myself, there's:

(1) Type promotion - defines what precision a computation should take place in.
(2) Output type - defines what precision the output should be in.

(1) and (2) don't have to match.

I think there are actually three issues here because of fp16/bf16:

(1) Type Promotion - promotes input types before computation, which can impact the output type. This doesn't use user-defined casts, but would use primTorch casts.
(2) Output Type - promotes the output of a computation, and explicitly specifies the output type of a computation.
(3) AMP behavior - ...
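A tiny eager-mode illustration of how (1) and (2) can differ (plain PyTorch, not nvFuser-specific):

import torch

a = torch.randn(3, dtype=torch.half)
b = torch.randn(3, dtype=torch.float)

# (1) type promotion: the compute precision is derived from the input dtypes
print(torch.result_type(a, b))  # torch.float32
print((a + b).dtype)            # torch.float32

# (2) output type: requested separately, and it doesn't have to match (1)
c = (a + b).to(torch.half)      # compute in fp32, store the result as fp16
print(c.dtype)                  # torch.float16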

If we just think about a program involving fp16/bf16, then for every operator with only fp16 or bf16 inputs we should do: fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16

Unless we have an explicit out type (treating cast as a set + specified out type). Then we actually convert that out type to a non fp16/bf16 value.

The optimization we effectively want is, if we have a series of the above: fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16 -> fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16

We want to translate it to: fp16 inputs -> promote inputs to fp32 -> compute -> compute -> set output type of fp16
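To make the before/after concrete, a minimal sketch in the same FusionDefinition API as the repro (function names are made up for illustration):

from torch._C._nvfuser import FusionDefinition, DataType

def per_op_round_trips(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Half)
    # op 1: promote -> compute -> set output type of fp16
    T1 = fd.ops.cast(T0, dtype=DataType.Float)
    T2 = fd.ops.neg(T1)
    T3 = fd.ops.cast(T2, dtype=DataType.Half)
    # op 2: promote -> compute -> set output type of fp16 (the Half round trip in between is wasted work)
    T4 = fd.ops.cast(T3, dtype=DataType.Float)
    T5 = fd.ops.exp(T4)
    T6 = fd.ops.cast(T5, dtype=DataType.Half)
    fd.add_output(T6)

def casts_cancelled(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Half)
    # promote once, keep every compute in fp32, demote only at the output
    T1 = fd.ops.cast(T0, dtype=DataType.Float)
    T2 = fd.ops.exp(fd.ops.neg(T1))
    T3 = fd.ops.cast(T2, dtype=DataType.Half)
    fd.add_output(T3)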

Side note: this seems to be very similar to what people want/need in quantization.

So the real question here is how we want to do the above optimization, because nvFuser only has explicit casts, except for the segmentation rule, which may or may not put an intermediate at segmentation boundaries as fp16/bf16 depending on how that value is computed.

jjsjann123 commented 2 years ago

If we just think about a program involving fp16/bf16, then for every operator with only fp16 or bf16 inputs we should do: fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16

I don't have a problem with this as the default behavior. Users / the framework might have a stronger opinion on us respecting the compute math type, but that's a different conversation, and it could also be revisited once we add fp16 math support.

The optimization we effectively want is, if we have a series of the above: fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16 -> fp16 inputs -> promote inputs to fp32 -> compute -> set output type of fp16

We want to translate it to: fp16 inputs -> promote inputs to fp32 -> compute -> compute -> set output type of fp16

Do we want to handle this in integration? We are given an FX graph, so we can cancel out neighboring casts, or even propagate casts if we want to get extra fancy. I'm also leaning towards having this done inside the codegen IR; would a transformation like this be tricky to pull off at the codegen level? cc'ing @naoyam @zasdfgbnm
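For what the integration-side version could look like, here is a toy torch.fx pass that cancels a half()/float() round trip between two ops (the helper name cancel_cast_pairs is made up for this sketch; the real integration would match nvprim cast ops rather than these tensor methods):

import torch
import torch.fx as fx

def cancel_cast_pairs(gm: fx.GraphModule) -> fx.GraphModule:
    # drop x.half().float() round trips so the intermediate stays in fp32
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "float":
            producer = node.args[0]
            if (
                isinstance(producer, fx.Node)
                and producer.op == "call_method"
                and producer.target == "half"
                and len(producer.users) == 1
            ):
                node.replace_all_uses_with(producer.args[0])
                gm.graph.erase_node(node)
                gm.graph.erase_node(producer)
    gm.graph.lint()
    gm.recompile()
    return gm

def f(x):
    y = torch.relu(x.float()).half()        # promote -> compute -> demote
    return torch.sigmoid(y.float()).half()  # promote -> compute -> demote

gm = cancel_cast_pairs(fx.symbolic_trace(f))
print(gm.graph)  # the half()/float() pair between relu and sigmoid is gone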