Closed Tlntin closed 10 months ago
Not a bug, to be confirmed @Tlntin
Root causes of the issue observed:
Here's the modified code passing all tests.
import unittest
import numpy as np
import torch
from polygraphy.backend.trt import EngineFromNetwork, TrtRunner
import tensorrt_llm
from tensorrt_llm import Tensor
import math
import tensorrt as trt
import numpy as np
from tensorrt_llm.functional import (
Tensor, shape, concat, constant, arange, outer, unary,
partial, expand, elementwise_binary, shape, pow
)
# Convenience wrappers that pre-bind the TensorRT operation enum, so the graph
# construction in the test below can call log(x), ceil(x) and div(a, b) directly.
# `partial` is the name imported from tensorrt_llm.functional above.
log = partial(unary, op=trt.UnaryOperation.LOG)  # element-wise LOG unary op
ceil = partial(unary, op=trt.UnaryOperation.CEIL)  # element-wise CEIL unary op
div = partial(elementwise_binary, op=trt.ElementWiseOperation.DIV)  # element-wise DIV binary op
class TestFunctional(unittest.TestCase):
    """Check a TensorRT-LLM graph against an equivalent PyTorch computation.

    The graph computes `ntk_alpha`, a scaled `base`, `inv_freq` and the
    `freqs` outer product; every intermediate is exposed as a network output
    so it can be compared with the PyTorch reference value by value.
    """

    def setUp(self):
        # Only report errors from the TensorRT-LLM logger while building.
        tensorrt_llm.logger.set_level('error')

    def test_case(self):
        """Build the network, run it, and compare each output with PyTorch."""
        dtype = 'float32'
        batch_size = 1
        input_seq_len = 3727
        per_head_dim = 128
        mha_seq_len = 2048
        base = 10000.0
        input_tensor = torch.rand([batch_size, input_seq_len], dtype=torch.float32)

        def f32_const(value):
            # Scalar float32 constant layer for the TensorRT graph.
            return constant(np.array(value, dtype=np.float32))

        # ---- TensorRT-LLM side: build the graph -------------------------
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()

            def expose(tensor, name):
                # Name the underlying TensorRT tensor and mark it as an output.
                raw = tensor.trt_tensor
                raw.name = name
                network.mark_output(raw)

            trt_input = Tensor(name='input_tensor',
                               shape=input_tensor.shape,
                               dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            input_len = shape(trt_input, 1)

            # log2(input_len / mha_seq_len) + 1, evaluated in float32 inside
            # the graph as ln(x) / ln(2) + 1.
            length_ratio = input_len.cast(trt.float32) / f32_const(mha_seq_len)
            trt_context_value = (
                log(length_ratio) / f32_const(math.log(2.0)) + f32_const(1.0)
            )

            # 2 ** ceil(context_value) - 1, then raised to d / (d - 2).
            trt_ntk_alpha = pow(f32_const(2), ceil(trt_context_value)) - 1.0
            trt_ntk_alpha = pow(trt_ntk_alpha, (per_head_dim / (per_head_dim - 2)))
            expose(trt_ntk_alpha, 'ntk_alpha')

            trt_base = f32_const(base) * trt_ntk_alpha
            expose(trt_base, 'base')

            trt_temp1 = constant(
                np.arange(0, per_head_dim, 2, dtype=np.float32) / per_head_dim
            )
            expose(trt_temp1, 'temp1')

            trt_temp2 = pow(trt_base, trt_temp1)
            expose(trt_temp2, 'temp2')

            # Explicit DIV of a scalar 1 by temp2 (rather than a ones-vector
            # divided through the `/` operator).
            trt_inv_freq = div(f32_const(1), trt_temp2)
            expose(trt_inv_freq, 'inv_freq')

            trt_seq = arange(
                constant(np.array(0, dtype=np.int32)), input_len * 2, dtype="int32"
            )
            expose(trt_seq, 'seq')

            trt_freqs = outer(
                trt_seq.cast(trt.float32), trt_inv_freq.cast(trt.float32)
            )

            # Equivalent of rearrange(emb, "n d -> 1 n 1 d") followed by a
            # broadcast over the batch. The result is built but not marked as
            # a network output.
            trt_emb = concat([trt_freqs, trt_freqs], dim=1)
            trt_emb = trt_emb.view(concat([1, input_len * 2, 1, per_head_dim]))
            trt_emb = expand(
                trt_emb, concat([batch_size, input_len * 2, 1, per_head_dim])
            )

            expose(trt_freqs, 'freqs')

        # ---- PyTorch side: same math, kept in float32 on the GPU ---------
        # Using torch tensors throughout (instead of Python floats, which are
        # fp64) keeps the reference in the same precision as the graph.
        ref_context_value = (
            torch.log(torch.Tensor([input_seq_len * 1. / mha_seq_len]).cuda())
            / torch.log(torch.Tensor([2.]).cuda())
            + 1
        )
        ref_ntk_alpha = torch.Tensor([2]).cuda() ** torch.ceil(ref_context_value) - 1
        ref_ntk_alpha = ref_ntk_alpha ** (per_head_dim / (per_head_dim - 2))
        ref_base = torch.Tensor([base]).cuda()
        ref_base = ref_base * ref_ntk_alpha
        ref_temp1 = (torch.arange(0, per_head_dim, 2).float() / per_head_dim).cuda()
        ref_temp2 = torch.pow(ref_base, ref_temp1)
        ref_inv_freq = 1.0 / ref_temp2
        ref_seq = torch.arange(0, input_seq_len * 2).int().cuda()
        ref_freqs = torch.outer(ref_seq.type_as(ref_inv_freq), ref_inv_freq)

        # ---- Build the engine, run it, and compare -----------------------
        # Each entry: (reference tensor, network output name, tolerance used
        # for both rtol and atol). The first four are compared exactly.
        comparisons = [
            (ref_ntk_alpha, 'ntk_alpha', 0),
            (ref_base, 'base', 0),
            (ref_temp1, 'temp1', 0),
            (ref_temp2, 'temp2', 0),
            (ref_seq, 'seq', 1e-9),
            (ref_inv_freq, 'inv_freq', 1e-9),
            (ref_freqs, 'freqs', 1e-9),
        ]
        build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
        with TrtRunner(build_engine) as runner:
            outputs = runner.infer(feed_dict={"input_tensor": input_tensor.numpy()})
            for reference, output_name, tolerance in comparisons:
                np.testing.assert_allclose(
                    reference.cpu().numpy(),
                    outputs[output_name],
                    rtol=tolerance,
                    atol=tolerance,
                )
# Run the test case when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Even with rtol and atol both set to 0, all tests should still pass (at least on my laptop with an A3000 GPU), since the results are now bitwise identical.
Lessons learned:
Environment
If applicable, please include the following: CPU architecture: x86_64 GPU name: NVIDIA A10 TensorRT branch: 9.0.0 TensorRT LLM: 0.1.3 Cuda: 12.1.66 Cudnn: 8.9.0 Container: registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1 NVIDIA driver version: 525.105.17 OS: Ubuntu 22.04.3 LTS x86_64 Kernel: 5.15.0-73-generic
相关代码:
预期结果:通过测试 实际结果:
现象描述:上述计算中,temp1, temp2误差均为0,但是经过div节点后,误差变成e-9左右,再经过一个outer计算后,误差扩大为e-4,这个误差已经远远大于float32的精度了,所以不应该出现这种情况。