NVIDIA / trt-samples-for-hackathon-cn

Simple samples for TensorRT programming
Apache License 2.0

arange bug in tensorrt 9.0.0 (Hackathon 2023) #89

Closed Tlntin closed 3 months ago

Tlntin commented 12 months ago

Environment

CPU architecture: x86_64
GPU name: NVIDIA A10
TensorRT branch: 9.0.0
TensorRT-LLM: 0.1.3
CUDA: 12.1.66
cuDNN: 8.9.0
Container: registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1
NVIDIA driver version: 525.105.17
OS: Ubuntu 22.04.3 LTS x86_64
Kernel: 5.15.0-73-generic

Reproduction code

import unittest

import numpy as np
import torch
from polygraphy.backend.trt import EngineFromNetwork, TrtRunner

import tensorrt_llm
from tensorrt_llm.functional import Tensor, arange

class TestFunctional(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    def test_case(self):
        dtype = 'int32'
        seq_len = 3072
        input_length = torch.tensor(seq_len, dtype=torch.int32)

        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            input_len = Tensor(name='input_len',
                               shape=input_length.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            trt_seq = arange(0, seq_len * 2, dtype="int32")
            output = trt_seq.trt_tensor
            output.name = 'output'
            network.mark_output(output)

        # reference result computed with PyTorch
        seq = torch.arange(0, seq_len * 2).int()
        # build and run the TensorRT engine
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={"input_len": input_length.numpy()})
        output = outputs['output']
        np.testing.assert_allclose(seq, output, atol=1e-5)

if __name__ == "__main__":
    unittest.main()

Error message:

Traceback (most recent call last):
  File "/root/workspace/trt2023/tensorrt_llm_july-release-v1/examples/qwen/test_caset2.py", line 32, in test_case
    trt_seq = arange(0, seq_len * 2, dtype="int32")
  File "/usr/local/lib/python3.8/dist-packages/tensorrt_llm/functional.py", line 918, in arange
    layer.set_alpha(start)
AttributeError: 'tensorrt.tensorrt.IFillLayer' object has no attribute 'set_alpha'
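
For reference, the missing attribute can be checked directly against the raw TensorRT 9.0.0 Python bindings, without going through TensorRT-LLM. This is only a minimal sketch; the fill shape [8] is arbitrary:

import tensorrt as trt

# Build a throwaway network containing a single LINSPACE fill layer,
# using the same add_fill call pattern as tensorrt_llm.functional.arange.
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
layer = network.add_fill([8], trt.FillOperation.LINSPACE)

print(hasattr(layer, "set_alpha"))  # expected: False
print(hasattr(layer, "set_beta"))   # expected: False
print(hasattr(layer, "set_input"))  # expected: True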

Cause of the error: judging from the implementation of arange, when both arguments are plain int values it calls layer.set_alpha and layer.set_beta. According to the NVIDIA TensorRT documentation (link 1, link 2), IFillLayer only provides the set_input method. The arange implementation is shown below:

def arange(start: Union[Tensor, int], end: Union[Tensor, int],
           dtype: str) -> Tensor:
    '''
    Add an operation to fill a 1D tensor.

    The tensor is filled with the values between start and end with a step of 1
    between the different elements. In pseudo-code, it corresponds to a tensor
    populated with the values:

        output = Tensor([dtype(ii) for ii in range(start, end, 1)])

    For example, a call to arange(3, 6, 'int32') will add an operation to the
    TensorRT graph that will produce [3, 4, 5] when executed. The call to
    arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].

    This operation is implemented using a tensorrt.IFillLayer in
    trt.FillOperation.LINSPACE mode.

    Parameters:
        start : Union[Tensor, int]
            The starting point of the range.

        end : Union[Tensor, int]
            The end point of the range.

        dtype : str
            The type of the elements. See _str_to_trt_dtype_dict in _utils.py
            for a list of supported types and type names.

    Returns:
        The tensor produced by the fill layer. It is a 1D tensor containing
        `end-start` elements of type `dtype`.
    '''
    if isinstance(start, int):
        step = 1
        assert isinstance(end, int)
        assert isinstance(step, int)

        num = len(range(start, end, step))

        layer = default_trtnet().add_fill([num], trt.FillOperation.LINSPACE)
        layer.set_output_type(0, str_dtype_to_trt(dtype))
        layer.set_alpha(start)
        layer.set_beta(step)
        return _create_tensor(layer.get_output(0), layer)
    elif isinstance(start, Tensor):
        step = constant(int32_array([1]))
        assert isinstance(end, Tensor)
        assert isinstance(step, Tensor)

        num = end - start
        num = num.view([1])

        layer = default_trtnet().add_fill([0], trt.FillOperation.LINSPACE)
        layer.set_input(0, num.trt_tensor)  # rank = 1
        layer.set_input(1, start.trt_tensor)  # rank = 0
        layer.set_input(2, step.trt_tensor)  # rank = 1
        return _create_tensor(layer.get_output(0), layer)
    else:
        raise TypeError("%s is not supported" % type(start))

When both arguments are plain int values, the layer.set_alpha(start) and layer.set_beta(step) calls above are no longer available in this TensorRT version, which is why the error is raised.
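
A possible call-site workaround (only a sketch, not verified against TensorRT-LLM 0.1.3): pass start and end as constant Tensors so that arange takes its Tensor branch, which uses IFillLayer.set_input and never touches the missing set_alpha/set_beta. Whether constant() accepts the rank-0 arrays used here is an assumption, and the code must run inside the same net_guard block as the rest of the network definition.

import numpy as np
from tensorrt_llm.functional import arange, constant

seq_len = 3072
# Rank-0 (scalar) start and end, matching the "rank = 0" start expected by
# the Tensor branch of arange; plain Python ints would hit the broken branch.
start_t = constant(np.array(0, dtype=np.int32))
end_t = constant(np.array(seq_len * 2, dtype=np.int32))
trt_seq = arange(start_t, end_t, dtype="int32")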

This bug has been confirmed with the mentor; NVIDIA internal tracking ID: 4285134.