pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

❓ [Question] Is there any plan to support bfloat16 compile #2940

Closed · leeeizhang closed this 1 week ago

leeeizhang commented 2 weeks ago

What you have already tried

NVIDIA TensorRT has supported bf16 precision since TensorRT >= 9.2.

However, the latest torch_tensorrt (torch_tensorrt==2.3.0 with tensorrt==10.0.1) does not support it yet.

Is there any plan to support bfloat16 in future versions? bf16 is very popular in LLM inference.

import torch
import torch_tensorrt

# bs, seq, dim, model, and calibrator are defined earlier in the script.
trt_model = torch_tensorrt.compile(
    module=torch.jit.script(model),
    inputs=[torch_tensorrt.Input(shape=(bs, seq, dim), dtype=torch.bfloat16)],
    enabled_precisions={torch.int8, torch.bfloat16, torch.float32},
    calibrator=calibrator,
    device={
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "dla_core": 0,
        "allow_gpu_fallback": True,
        "disable_tf32": True,
    },
)
Traceback (most recent call last):
  File "/data01/home/zhanglei.me/workspace/tensorrt_example/example_int8.py", line 38, in <module>
    trt_model = torch_tensorrt.compile(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/_compile.py", line 208, in compile
    compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/ts/_compiler.py", line 151, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/ts/_compile_spec.py", line 208, in _parse_compile_spec
    dtype=i.dtype.to(_C.dtype),
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/_enums.py", line 305, in to
    raise TypeError(
TypeError: Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: dtype.bf16

Environment

Build information about Torch-TensorRT can be found by turning on debug messages
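For reference, a sketch of turning those on (this uses the torch_tensorrt.logging module as it existed around the 2.3 release; the exact logging API has moved between versions):

import torch_tensorrt

# Raise the global reportable log level so build information is printed...
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)

# ...or scope debug output to a single block.
with torch_tensorrt.logging.debug():
    pass  # e.g. torch_tensorrt.compile(...)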

narendasan commented 1 week ago

The TorchScript frontend does not support BF16, but the Dynamo frontend does. If you would like to use TorchScript for deployment, you can still torch.jit.trace the result of compiling with the Dynamo frontend and use it as you would the original TorchScript module.
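A minimal sketch of that workflow (the toy model, shapes, and file name are placeholders; ir="dynamo" selects the Dynamo frontend):

import torch
import torch_tensorrt

# Placeholder model; any eval-mode nn.Module on the GPU works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
).eval().cuda().to(torch.bfloat16)

example_input = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

# Compile through the Dynamo frontend, which accepts BF16.
trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[example_input],
    enabled_precisions={torch.bfloat16},
)

# Re-trace the compiled module into TorchScript for deployment.
trt_ts = torch.jit.trace(trt_gm, example_input)
torch.jit.save(trt_ts, "trt_model_bf16.ts")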

leeeizhang commented 1 week ago

> The TorchScript frontend does not support BF16, but the Dynamo frontend does. If you would like to use TorchScript for deployment, you can still torch.jit.trace the result of compiling with the Dynamo frontend and use it as you would the original TorchScript module.

Many thanks! It works for me.