NVIDIA / trt-samples-for-hackathon-cn

Simple samples for TensorRT programming
Apache License 2.0
1.47k stars 337 forks source link

TensorRT-LLM子模型单测正常,搭建大模型后结果错误,但mark_output一下就正确 (Hackathon 2023) #97

Open EdVince opened 11 months ago

EdVince commented 11 months ago

Environment

- CPU architecture: x86_64
- GPU name: NVIDIA A10
- TensorRT branch: 9.0.0
- TensorRT-LLM: 0.1.3
- CUDA: 12.1.66
- cuDNN: 8.9.0
- Container: registry.cn-hangzhou.aliyuncs.com/trt-hackathon/trt-hackathon:final_v1
- NVIDIA driver version: 525.105.17
- OS: Ubuntu 22.04.3 LTS x86_64
- Kernel: 5.15.0-73-generic

问题简要描述

实现了WhisperDecoderAttention类,支持self/cross和with/without kv_cache的Attention,该Attention单测正常,但用它搭建起WhisperDecoderLayer后,在self+with kv_cache下计算结果不正确,但如果将中间结果mark_output一下,计算就正确了。猜测是加了mark_output破坏了原始的图融合。

复现代码

模型代码:

import enum
import math
from dataclasses import dataclass
from typing import Optional

import torch
import numpy as np

import tensorrt as trt

from ..._common import default_net, precision
from ..._utils import str_dtype_to_trt
from ...functional import (Tensor, RaggedTensor, ACT2FN, 
                           unsqueeze, gelu, shape, gather, 
                           concat, view, permute, constant, 
                           split, matmul, softmax, cast,
                           identity)
from ...layers import Attention, LayerNorm, ColumnLinear, Conv2d
from ...module import Module, ModuleList
from ...parameter import Parameter
from ...layers.linear import ColumnLinear, RowLinear

def squeeze(input, axis):
    """Remove dimension ``axis`` from ``input`` by reshaping with ``view``.

    Defined locally (presumably because the TensorRT-LLM version in use has no
    squeeze op): the output shape is assembled from the runtime shape of
    ``input`` with the ``axis`` entry left out.

    Args:
        input: TensorRT-LLM tensor to squeeze.
        axis: dimension to drop; negative values count from the end, as in
            ``numpy.squeeze``.

    Returns:
        A tensor of rank ``input.ndim() - 1``.

    Raises:
        ValueError: if ``axis`` is outside ``[-ndim, ndim)``.

    NOTE(review): the size of ``axis`` is not checked to be 1; dropping a
    dimension of another size would make the ``view`` element count mismatch.
    """
    dims = input.ndim()
    if not -dims <= axis < dims:
        raise ValueError(f"axis {axis} is out of range for a tensor of rank {dims}")
    # Normalise negative axes. Previously a negative (or too large) axis never
    # matched any index, so the "squeeze" silently became a no-op reshape.
    axis %= dims
    input_shape = shape(input)
    out_shapes = [gather(input_shape, 0, i) for i in range(dims) if i != axis]
    return view(input, concat(out_shapes))

class WhisperEncoderLayer(Module):
    """One pre-LayerNorm Whisper encoder block built from TensorRT-LLM layers.

    Mirrors the PyTorch reference: LayerNorm -> self-attention -> residual add,
    then LayerNorm -> fc1 -> activation -> fc2 -> residual add.
    """

    def __init__(self, d_model=512, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048):
        super().__init__()
        width = d_model
        self.embed_dim = width
        # NOTE(review): the positional 1 is presumably max_position_embeddings;
        # confirm against the Attention signature.
        self.self_attn = Attention(width, encoder_attention_heads, 1)
        self.self_attn_layer_norm = LayerNorm(width)
        self.activation_fn = ACT2FN[activation_function]
        # NOTE(review): fc2 is a ColumnLinear rather than a RowLinear; that is
        # presumably only equivalent without tensor parallelism (tp_size == 1).
        self.fc1 = ColumnLinear(width, encoder_ffn_dim)
        self.fc2 = ColumnLinear(encoder_ffn_dim, width)
        self.final_layer_norm = LayerNorm(width)

    def forward(self, hidden_states: RaggedTensor):
        """Apply the block to ``hidden_states.data``; returns a plain Tensor.

        The row lengths / max row length are only forwarded to the attention.
        """
        lengths = hidden_states.row_lengths
        max_len = hidden_states.max_row_length
        x = hidden_states.data

        # Attention sub-block: pre-norm, attend, residual add.
        attn_in = RaggedTensor.from_row_lengths(self.self_attn_layer_norm(x), lengths, max_len)
        x = x + self.self_attn(attn_in).data

        # Feed-forward sub-block: pre-norm, fc1 + activation, fc2, residual add.
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(x)))
        x = x + self.fc2(ffn)
        return x

class WhisperEncoder(Module):
    """Whisper audio encoder in TensorRT-LLM.

    Pipeline: two GELU convolutions, an additive positional table, a stack of
    WhisperEncoderLayer blocks, and a final LayerNorm.
    """

    def __init__(self,
                 d_model=512, num_mel_bins=80, max_source_positions=1500,
                 encoder_layers=6,
                 encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048):
        super().__init__()

        # This should be a Conv1d, but TensorRT-LLM does not implement it yet,
        # so a Conv2d with a (1, 3) kernel over a height-1 input stands in for
        # it; forward() adds and later removes that height axis.
        self.conv1 = Conv2d(num_mel_bins, d_model, kernel_size=(1, 3), padding=(0, 1))
        self.conv2 = Conv2d(d_model, d_model, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))

        # Plain float32 array (not a Parameter), baked into the network with
        # constant() when forward() runs. It starts as all zeros, so real
        # positional weights must be assigned to this attribute before then.
        self.embed_positions_weight = np.zeros((1, max_source_positions, d_model), dtype=np.float32)

        self.layers = ModuleList([
            WhisperEncoderLayer(d_model=d_model,
                                encoder_attention_heads=encoder_attention_heads,
                                activation_function=activation_function,
                                encoder_ffn_dim=encoder_ffn_dim)
            for _ in range(encoder_layers)
        ])

        self.layer_norm = LayerNorm(d_model)

    def forward(self, input_features: RaggedTensor):
        """Encode ragged log-mel features; returns the final normalised Tensor.

        ``input_features.data`` is presumably (batch, num_mel_bins, frames);
        TODO confirm against the runtime inputs.
        """
        lengths = input_features.row_lengths
        max_len = input_features.max_row_length

        # Insert a height-1 axis so the Conv2d layers act as 1-D convolutions.
        x = unsqueeze(input_features.data, 2)
        x = gelu(self.conv1(x))
        x = gelu(self.conv2(x))
        # Remove the height axis, then swap the last two axes (channels last).
        x = permute(squeeze(x, 2), [0, 2, 1])

        x = x + constant(self.embed_positions_weight)

        for encoder_layer in self.layers:
            x = encoder_layer(RaggedTensor.from_row_lengths(x, lengths, max_len))

        return self.layer_norm(x)

# class SimpleConvTRTLLMNet(Module):

#     def __init__(self):
#         super().__init__()
#         self.encoder = WhisperEncoder()

#     def forward(self, input_features: RaggedTensor):

#         hidden_states = self.encoder(input_features)

#         hidden_states.mark_output('output', str_dtype_to_trt('float32'))

#         return hidden_states

#     def prepare_inputs(self):

#         input_features_data = Tensor(name='data',
#                     dtype=trt.float32,
#                     shape=[1, 80, 3000])
#         input_features_length = Tensor(name='length',
#                     dtype=trt.float32,
#                     shape=[1])

#         input_features = RaggedTensor.from_row_lengths(input_features_data, input_features_length)

#         return (input_features)

class AttentionMaskType(enum.Enum):
    """Attention-mask variants, identified by explicit integer values 0-2.

    NOTE(review): not referenced elsewhere in this file; each member's exact
    masking semantics are presumably the conventional ones suggested by its name.
    """

    padding = 0
    causal = 1
    bidirectional = 2

class PositionEmbeddingType(enum.Enum):
    """Position-embedding scheme selector.

    Values are 1, 2, 3 in declaration order, exactly what ``enum.auto()``
    assigns for a plain ``Enum``.
    """

    learned_absolute = 1
    rope = 2
    alibi = 3

@dataclass
class InflightBatchingParam:
    """Bundle of tensors and limits for in-flight-batching attention.

    NOTE(review): not referenced elsewhere in this file; field meanings are
    inferred from their names only.
    """

    host_beam_widths: Tensor
    cache_indir_pointers: Tensor
    host_req_cache_max_seq_lengths: Tensor
    host_input_lengths: Tensor
    past_key_value_pointers: Tensor
    max_input_length: int
    max_beam_width: int
    kv_orig_quant_scale: Optional[Tensor] = None
    kv_quant_orig_scale: Optional[Tensor] = None
    use_int8_kv_cache: bool = False

    def __post_init__(self):
        """Validate the integer limits.

        Raises:
            ValueError: if ``max_input_length`` or ``max_beam_width`` is not
                positive.
        """
        # Explicit checks instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        if self.max_input_length <= 0:
            raise ValueError(f"max_input_length must be positive, got {self.max_input_length}")
        if self.max_beam_width <= 0:
            raise ValueError(f"max_beam_width must be positive, got {self.max_beam_width}")

class WhisperDecoderAttention(Module):
    """Whisper decoder attention built from plain TensorRT-LLM ops.

    One module covers the four modes of the PyTorch reference. ``forward``
    picks the mode from which optional arguments it receives:

    * neither ``key_value_states`` nor ``past_key_value``: self-attention over
      projections of ``hidden_states``;
    * only ``past_key_value``: self-attention where the cached keys/values are
      concatenated (sequence axis) with projections of the new tokens;
    * only ``key_value_states``: cross-attention over its projections;
    * both: cross-attention reusing the cache unchanged.

    The KV cache (both the input and the returned one) stacks keys and values
    on axis 0, i.e. ``concat([key, value], dim=0)``. With batch size 1 that is
    ``[2, heads, seq, head_size]``.

    NOTE(review): several constructor arguments are stored but not used by
    ``forward``: max_position_embeddings, the rotary settings, the int8 KV-cache
    scales, multi_block_mode and multi_query_mode.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 max_position_embeddings=0,
                 num_layers=1,
                 apply_query_key_layer_scaling=False,
                 bias=True,
                 dtype=None,
                 position_embedding_type=PositionEmbeddingType.learned_absolute,
                 neox_rotary_style=False,
                 use_int8_kv_cache=False,
                 rotary_embedding_percentage=1.0,
                 tp_group=None,
                 tp_size=1,
                 multi_block_mode=False,
                 multi_query_mode=False):
        super().__init__()

        # Un-split model width: the input feature size of every projection.
        self.embed_dim = hidden_size
        self.attention_head_size = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_attention_kv_heads = 1 if multi_query_mode else self.num_attention_heads
        # Per-rank width of the attention output.
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = max_position_embeddings

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        # Scores are divided by sqrt(head_size), and additionally by num_layers
        # when query/key layer scaling is enabled.
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = 1
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers

        self.position_embedding_type = position_embedding_type
        self.multi_block_mode = multi_block_mode
        self.multi_query_mode = multi_query_mode

        self.rotary_embedding_dim = 0
        self.neox_rotary_style = neox_rotary_style
        if self.position_embedding_type == PositionEmbeddingType.rope:
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)
            # TODO: Once we add RotaryEmbedding outside GPTAttention plugin,
            #       we need to set it up here

        self.dtype = dtype

        self.use_int8_kv_cache = use_int8_kv_cache
        if self.use_int8_kv_cache:
            self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
            self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
        else:
            self.register_parameter('kv_orig_quant_scale', None)
            self.register_parameter('kv_quant_orig_scale', None)

        # Separate q/k/v projections (not a fused QKV) so the PyTorch Whisper
        # weights map one-to-one. As in the reference, k_proj has no bias.
        # NOTE(review): in multi_query_mode, k/v still produce hidden_size
        # features and forward() still splits them into num_attention_heads
        # heads, so multi-query attention is not actually supported here.
        self.q_proj = ColumnLinear(hidden_size,
                                   hidden_size,
                                   bias=bias,
                                   dtype=dtype,
                                   tp_group=tp_group,
                                   tp_size=tp_size)
        self.k_proj = ColumnLinear(hidden_size,
                                   hidden_size,
                                   bias=False,
                                   dtype=dtype,
                                   tp_group=tp_group,
                                   tp_size=tp_size)
        self.v_proj = ColumnLinear(hidden_size,
                                   hidden_size,
                                   bias=bias,
                                   dtype=dtype,
                                   tp_group=tp_group,
                                   tp_size=tp_size)
        self.dense = RowLinear(hidden_size,
                               hidden_size,
                               bias=bias,
                               dtype=dtype,
                               tp_group=tp_group,
                               tp_size=tp_size)

    def forward(self,
                hidden_states: RaggedTensor,
                key_value_states: Optional[Tensor] = None,
                past_key_value: Optional[Tensor] = None
                ):
        """Run one attention call.

        Args:
            hidden_states: ragged batch whose ``data`` is the query input. Its
                row lengths and max length are only passed through to the
                returned RaggedTensor.
            key_value_states: encoder states for cross-attention, consumed
                directly as a plain Tensor; None for self-attention.
            past_key_value: cache with keys and values stacked on axis 0 (it is
                split with ``split(past_key_value, 1, dim=0)``), or None.
                NOTE(review): that split only yields exactly two halves when
                the batch size is 1.

        Returns:
            ``(context, present_key_value)``. ``context`` wraps the ``dense``
            output in a RaggedTensor with the input's row lengths.
            ``present_key_value`` is ``concat([key, value], dim=0)``, the same
            stacked layout as the incoming cache.
        """
        input_lengths = hidden_states.row_lengths
        max_input_length = hidden_states.max_row_length
        hidden_states = hidden_states.data

        def transpose_for_scores(x):
            # Split the feature axis into heads, then move heads before the
            # sequence axis: (B, S, heads * head_size) -> (B, heads, S, head_size).
            new_x_shape = concat([
                shape(x, 0),
                shape(x, 1), self.num_attention_heads, self.attention_head_size
            ])
            return x.view(new_x_shape).permute([0, 2, 1, 3])

        query_states = transpose_for_scores(self.q_proj(hidden_states))

        is_cross_attention = key_value_states is not None
        is_reuse = past_key_value is not None

        if is_cross_attention and is_reuse:
            # Cross-attention with a cache: keys/values come straight from it.
            # The k/v projections of a throw-away zero vector are overwritten
            # immediately. NOTE(review): presumably kept so k_proj/v_proj stay
            # referenced in this mode; confirm it is still needed. The vector is
            # sized from the projections' input width (was hard-coded to 512,
            # which broke any other d_model). It is float32, which may mismatch
            # a float16 ``dtype``.
            dummy_key_value_states = constant(
                np.zeros((self.embed_dim, ), dtype=np.float32))
            key_states = self.k_proj(dummy_key_value_states)
            value_states = self.v_proj(dummy_key_value_states)
            key_states, value_states = split(past_key_value, 1, dim=0)
        elif is_cross_attention:
            key_states = transpose_for_scores(self.k_proj(key_value_states))
            value_states = transpose_for_scores(self.v_proj(key_value_states))
        elif is_reuse:
            past_key_states, past_value_states = split(past_key_value, 1, dim=0)
            curr_key_states = transpose_for_scores(self.k_proj(hidden_states))
            curr_value_states = transpose_for_scores(self.v_proj(hidden_states))
            # Workaround from the original report: without marking this
            # intermediate as an extra network output ('hook0'), the full decoder
            # layer produced wrong results in this mode. The cause is presumably
            # a TensorRT fusion across the split/concat that the extra output
            # prevents.
            # NOTE(review): every instance running in this mode marks an output
            # named 'hook0'. Stacking several layers would presumably collide;
            # confirm before building a multi-layer decoder.
            past_value_states.mark_output('hook0', str_dtype_to_trt('float32'))
            key_states = concat([past_key_states, curr_key_states], dim=2)
            value_states = concat([past_value_states, curr_value_states], dim=2)
        else:
            key_states = transpose_for_scores(self.k_proj(hidden_states))
            value_states = transpose_for_scores(self.v_proj(hidden_states))

        # Cache to return: keys and values stacked on axis 0.
        present_key_value = concat([key_states, value_states], dim=0)

        # (B, heads, S_kv, head_size) -> (B, heads, head_size, S_kv) for Q @ K^T.
        key_t = key_states.permute([0, 1, 3, 2])

        with precision('float32'):
            attention_scores = matmul(cast(query_states, 'float32'), cast(key_t, 'float32'))
            attention_scores = attention_scores / self.norm_factor
            attention_probs = softmax(attention_scores, dim=-1)

        # Back to (B, S_q, heads, head_size), then merge heads into the feature axis.
        context = matmul(attention_probs, value_states).permute([0, 2, 1, 3])
        context = context.view(concat([shape(context, 0), shape(context, 1), self.hidden_size]))

        context = self.dense(context)

        context = RaggedTensor.from_row_lengths(context, input_lengths, max_input_length)

        return context, present_key_value

class WhisperDecoderLayer(Module):
    """One Whisper decoder block in TensorRT-LLM.

    Three pre-LayerNorm sub-blocks, each followed by a residual add:
    self-attention, cross-attention over the encoder states, and an MLP.
    Returns the hidden states together with the present self- and
    cross-attention caches produced by the two attention modules.
    """

    def __init__(self, d_model=512, decoder_attention_heads=8, activation_function='gelu', decoder_ffn_dim=2048):
        super().__init__()
        self.embed_dim = d_model

        self.self_attn = WhisperDecoderAttention(d_model, decoder_attention_heads)
        self.activation_fn = ACT2FN[activation_function]

        self.self_attn_layer_norm = LayerNorm(d_model)
        self.encoder_attn = WhisperDecoderAttention(d_model, decoder_attention_heads)
        self.encoder_attn_layer_norm = LayerNorm(d_model)
        # NOTE(review): fc2 is a ColumnLinear (not RowLinear); presumably only
        # equivalent without tensor parallelism.
        self.fc1 = ColumnLinear(d_model, decoder_ffn_dim)
        self.fc2 = ColumnLinear(decoder_ffn_dim, d_model)
        self.final_layer_norm = LayerNorm(d_model)

    def forward(self,
        hidden_states: RaggedTensor,
        encoder_hidden_states: Optional[Tensor] = None,
        self_attn_past_key_value: Optional[Tensor] = None,
        cross_attn_past_key_value: Optional[Tensor] = None,
    ):
        """Apply the block.

        Returns:
            ``(hidden_states, present_key_value, cross_attn_present_key_value)``.
            The first is a plain Tensor; the other two are the caches returned
            by the self- and cross-attention modules.
        """
        lengths = hidden_states.row_lengths
        max_len = hidden_states.max_row_length

        def as_ragged(t):
            # Re-attach the incoming row lengths so attention receives a RaggedTensor.
            return RaggedTensor.from_row_lengths(t, lengths, max_len)

        x = hidden_states.data

        # Self-attention sub-block.
        attn_out, self_present = self.self_attn(
            hidden_states=as_ragged(self.self_attn_layer_norm(x)),
            key_value_states=None,
            past_key_value=self_attn_past_key_value,
        )
        x = x + attn_out.data

        # Cross-attention sub-block.
        attn_out, cross_present = self.encoder_attn(
            hidden_states=as_ragged(self.encoder_attn_layer_norm(x)),
            key_value_states=encoder_hidden_states,
            past_key_value=cross_attn_past_key_value,
        )
        x = x + attn_out.data

        # Feed-forward sub-block.
        ffn = self.fc2(self.activation_fn(self.fc1(self.final_layer_norm(x))))
        x = x + ffn

        return x, self_present, cross_present

class SimpleConvTRTLLMNet(Module):
    """Engine-building harness wrapping a single WhisperDecoderLayer.

    ``forward`` marks three float32 network outputs (``output0`` for the hidden
    states, ``output1`` for the self-attention cache, ``output2`` for the
    cross-attention cache). ``prepare_inputs`` declares the static inputs.
    """

    def __init__(self):
        super().__init__()
        self.layer = WhisperDecoderLayer()

    def forward(self, hidden_states: RaggedTensor, encoder_hidden_states: Tensor, self_attn_past_key_value: Tensor, cross_attn_past_key_value: Tensor):
        """Run the layer, mark its three results as outputs, return the hidden states."""
        results = self.layer(hidden_states=hidden_states,
                             encoder_hidden_states=encoder_hidden_states,
                             self_attn_past_key_value=self_attn_past_key_value,
                             cross_attn_past_key_value=cross_attn_past_key_value)

        for output_name, tensor in zip(('output0', 'output1', 'output2'), results):
            tensor.mark_output(output_name, str_dtype_to_trt('float32'))

        return results[0]

    def prepare_inputs(self):
        """Declare the static float32 network inputs, in binding order.

        Returns:
            The tuple passed positionally to ``forward``.
        """
        f32 = trt.float32

        data = Tensor(name='data', dtype=f32, shape=[1, 1, 512])
        # NOTE(review): row lengths are declared float32; an integer type would
        # be the usual choice. Confirm what RaggedTensor and the runtime feeding
        # this binding expect before changing it.
        length = Tensor(name='length', dtype=f32, shape=[1])
        ragged_hidden = RaggedTensor.from_row_lengths(data, length)

        encoder_states = Tensor(name='encoder_hidden_states', dtype=f32, shape=[1, 1500, 512])

        # Caches stack keys and values on axis 0: [2, heads, seq, head_size].
        self_cache = Tensor(name='self_attn_past_key_value', dtype=f32, shape=[2, 8, 23, 64])
        cross_cache = Tensor(name='cross_attn_past_key_value', dtype=f32, shape=[2, 8, 1500, 64])

        return (ragged_hidden, encoder_states, self_cache, cross_cache)

if __name__ == '__main__':
    # Smoke check: building the harness (and its decoder layer) must not raise.
    net = SimpleConvTRTLLMNet()

生成对比用的pytorch模型:

import math
import torch
import torch.nn as nn
from activations import ACT2FN
from typing import Optional, Tuple

class WhisperEncoderAttention(nn.Module):
    """Unmasked multi-head self-attention for the PyTorch Whisper encoder reference.

    ``dropout`` is stored but never applied.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim), applied to the queries before the score matmul.
        self.scaling = self.head_dim**-0.5

        # Creation order is kept as-is: it determines the random initialisation
        # and the state_dict order.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split features into heads: (bsz, seq, embed) -> contiguous (bsz, heads, seq, head_dim)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Self-attend over ``hidden_states`` of shape (bsz, tgt_len, embed_dim)."""
        # e.g. bsz=1, tgt_len=1500; the last dim depends on the model size.
        bsz, tgt_len, _ = hidden_states.size()
        flat = (bsz * self.num_heads, -1, self.head_dim)

        q = self.q_proj(hidden_states) * self.scaling
        k = self._shape(self.k_proj(hidden_states), -1, bsz)
        v = self._shape(self.v_proj(hidden_states), -1, bsz)

        q = self._shape(q, tgt_len, bsz).view(*flat)
        k = k.reshape(*flat)
        v = v.reshape(*flat)

        probs = nn.functional.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        out = torch.bmm(probs, v)

        out = out.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        out = out.reshape(bsz, tgt_len, self.embed_dim)
        return self.out_proj(out)

class WhisperEncoderLayer(nn.Module):
    """PyTorch reference encoder block: pre-norm self-attention then pre-norm MLP, each with a residual."""

    def __init__(self, d_model=512, encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048):
        super().__init__()
        self.embed_dim = d_model
        self.self_attn = WhisperEncoderAttention(embed_dim=d_model, num_heads=encoder_attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.activation_fn = ACT2FN[activation_function]
        self.fc1 = nn.Linear(d_model, encoder_ffn_dim)
        self.fc2 = nn.Linear(encoder_ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return the block output for ``hidden_states``."""
        attn = self.self_attn(hidden_states=self.self_attn_layer_norm(hidden_states))
        hidden_states = hidden_states + attn

        mlp = self.fc2(self.activation_fn(self.fc1(self.final_layer_norm(hidden_states))))
        return hidden_states + mlp

class WhisperEncoder(nn.Module):
    """PyTorch reference Whisper encoder.

    Pipeline: two GELU Conv1d layers, learned positional embeddings, a stack of
    encoder layers, and a final LayerNorm.
    """

    def __init__(self,
                 d_model=512, num_mel_bins=80, max_source_positions=1500,
                 encoder_layers=6,
                 encoder_attention_heads=8, activation_function='gelu', encoder_ffn_dim=2048):
        super().__init__()

        self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=3, padding=1)
        # Stride 2 halves the time axis (rounding up).
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)

        self.embed_positions = nn.Embedding(max_source_positions, d_model)

        self.layers = nn.ModuleList(
            WhisperEncoderLayer(d_model=d_model,
                                encoder_attention_heads=encoder_attention_heads,
                                activation_function=activation_function,
                                encoder_ffn_dim=encoder_ffn_dim)
            for _ in range(encoder_layers)
        )

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        input_features,  # e.g. (1, 80, 3000)
    ):
        """Encode ``input_features`` shaped (batch, num_mel_bins, frames)."""
        x = nn.functional.gelu(self.conv1(input_features))
        x = nn.functional.gelu(self.conv2(x)).permute(0, 2, 1)

        # The whole positional table is added, so for broadcasting the
        # post-conv time length must match max_source_positions.
        x = x + self.embed_positions.weight

        for layer in self.layers:
            x = layer(x)

        return self.layer_norm(x)

class WhisperDecoderAttention(nn.Module):
    """PyTorch reference decoder attention (BART-style), self or cross, with an optional KV cache.

    ``forward`` returns ``(attn_output, (key_states, value_states))``. The pair
    holds the keys and values actually attended over, shaped
    (bsz, num_heads, src_len, head_dim), ready to serve as the next cache.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        num_heads: int = 8,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # 1 / sqrt(head_dim), applied to the projected queries.
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split features into heads: (bsz, seq, embed) -> contiguous (bsz, heads, seq, head_dim)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    # Adapted from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ):
        """Attend from ``hidden_states`` (bsz, tgt_len, embed_dim).

        Without ``key_value_states`` this is self-attention: the current
        tokens are projected and any cached keys/values are prepended along
        the sequence axis. With ``key_value_states`` it is cross-attention: the
        cache is reused as-is when its length matches the encoder sequence,
        otherwise the encoder states are projected.
        """
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states) * self.scaling

        if key_value_states is None:
            # Self-attention.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value is not None:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
        elif past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]:
            # Cross-attention with a cache covering the full encoder sequence.
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            # Cross-attention without a usable cache.
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)

        present = (key_states, value_states)

        flat = (bsz * self.num_heads, -1, self.head_dim)
        q = self._shape(query_states, tgt_len, bsz).view(*flat)
        k = key_states.reshape(*flat)
        v = value_states.reshape(*flat)

        probs = nn.functional.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        out = torch.bmm(probs, v)

        out = out.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        out = out.reshape(bsz, tgt_len, self.embed_dim)

        return self.out_proj(out), present

class WhisperDecoderLayer(nn.Module):
    """PyTorch reference decoder block: self-attention, cross-attention, MLP.

    Each sub-block is pre-LayerNorm with a residual add. ``past_key_value`` is
    an optional 4-tuple ``(self_k, self_v, cross_k, cross_v)``. ``forward``
    returns the 2-tuple ``(hidden_states, present_key_value)``, where the
    present value is a 4-tuple in the same order. The ``-> torch.Tensor``
    annotation does not reflect that tuple.
    """

    def __init__(self, d_model = 512, decoder_attention_heads = 8, activation_function = 'gelu', decoder_ffn_dim = 2048):
        super().__init__()
        self.embed_dim = d_model

        self.self_attn = WhisperDecoderAttention(embed_dim=d_model, num_heads=decoder_attention_heads)
        self.activation_fn = ACT2FN[activation_function]

        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperDecoderAttention(d_model, decoder_attention_heads)
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, decoder_ffn_dim)
        self.fc2 = nn.Linear(decoder_ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> torch.Tensor:
        if past_key_value is None:
            self_past = cross_past = None
        else:
            # First two cached tensors belong to self-attention, last two to cross-attention.
            self_past = past_key_value[:2]
            cross_past = past_key_value[-2:]

        attn, self_present = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            key_value_states=None,
            past_key_value=self_past,
        )
        hidden_states = hidden_states + attn

        attn, cross_present = self.encoder_attn(
            hidden_states=self.encoder_attn_layer_norm(hidden_states),
            key_value_states=encoder_hidden_states,
            past_key_value=cross_past,
        )
        hidden_states = hidden_states + attn

        mlp = self.fc2(self.activation_fn(self.fc1(self.final_layer_norm(hidden_states))))
        hidden_states = hidden_states + mlp

        # Self-attention cache first, then cross-attention cache.
        return (hidden_states, self_present + cross_present)

class SimpleConvTorchNet(nn.Module):
    """PyTorch counterpart of the TensorRT-LLM harness: one WhisperDecoderLayer under the name ``layer``."""

    def __init__(self):
        super().__init__()
        self.layer = WhisperDecoderLayer()

    def forward(self, hidden_states, encoder_hidden_states, past_key_value):
        """Delegate to the wrapped layer; returns its ``(hidden_states, present_key_value)`` tuple."""
        return self.layer(hidden_states, encoder_hidden_states, past_key_value)

if __name__ == '__main__':

    torch_net = SimpleConvTorchNet()
    # Save the randomly initialised weights; the engine-build script reads them
    # back by their 'layer.*' state_dict keys.
    torch.save(torch_net.state_dict(), 'weight.pth')

    def run_and_report(self_cache_len=None):
        """Run one decoder step on random inputs and print the output shapes.

        With ``self_cache_len`` set, random self-attention caches of that length
        and full-length (1500) cross-attention caches are passed in. Tensors are
        drawn in the same order as before: hidden, encoder states, then caches.
        """
        hidden = torch.rand(1, 1, 512)
        encoder_states = torch.rand(1, 1500, 512)
        cache = None
        if self_cache_len is not None:
            cache = (torch.rand(1, 8, self_cache_len, 64), torch.rand(1, 8, self_cache_len, 64),
                     torch.rand(1, 8, 1500, 64), torch.rand(1, 8, 1500, 64))
        output = torch_net(hidden, encoder_states, cache)
        print(len(output), output[0].shape, [i.shape for i in output[1]])

    run_and_report()
    # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 1, 64]), torch.Size([1, 8, 1, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])]

    run_and_report(23)
    # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 24, 64]), torch.Size([1, 8, 24, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])]

    run_and_report(1)
    # 2 torch.Size([1, 1, 512]) [torch.Size([1, 8, 2, 64]), torch.Size([1, 8, 2, 64]), torch.Size([1, 8, 1500, 64]), torch.Size([1, 8, 1500, 64])]

构建engine:

import time
import torch

import tensorrt_llm
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.network import net_guard

def serialize_engine(engine, path):
    """Write a built engine to ``path`` and log the elapsed wall-clock time.

    Args:
        engine: the object returned by ``builder.build_engine``. It is copied
            into a ``bytearray`` before writing, so it must be convertible to
            bytes.
        path: destination file; it is opened in binary write mode and
            overwritten.
    """
    logger.info(f'Serializing engine to {path}...')
    started = time.time()
    with open(path, 'wb') as sink:
        sink.write(bytearray(engine))
    # Format the elapsed seconds as HH:MM:SS (sub-second durations print as 00:00:00).
    duration = time.strftime('%H:%M:%S', time.gmtime(time.time() - started))
    logger.info(f'Engine serialized. Total time: {duration}')

# PyTorch ``state_dict`` keys copied into the TensorRT-LLM model, listed in the
# same order the weights are assigned.
# NOTE(review): no ``k_proj.bias`` is copied for either attention block;
# presumably the PyTorch k_proj has no bias -- confirm against the model code.
WEIGHT_KEYS = (
    'layer.self_attn.q_proj.weight',
    'layer.self_attn.q_proj.bias',
    'layer.self_attn.k_proj.weight',
    'layer.self_attn.v_proj.weight',
    'layer.self_attn.v_proj.bias',
    'layer.self_attn.out_proj.weight',
    'layer.self_attn.out_proj.bias',
    'layer.self_attn_layer_norm.weight',
    'layer.self_attn_layer_norm.bias',
    'layer.encoder_attn.q_proj.weight',
    'layer.encoder_attn.q_proj.bias',
    'layer.encoder_attn.k_proj.weight',
    'layer.encoder_attn.v_proj.weight',
    'layer.encoder_attn.v_proj.bias',
    'layer.encoder_attn.out_proj.weight',
    'layer.encoder_attn.out_proj.bias',
    'layer.encoder_attn_layer_norm.weight',
    'layer.encoder_attn_layer_norm.bias',
    'layer.fc1.weight',
    'layer.fc1.bias',
    'layer.fc2.weight',
    'layer.fc2.bias',
    'layer.final_layer_norm.weight',
    'layer.final_layer_norm.bias',
)


def trtllm_param_name(torch_key):
    """Map a PyTorch ``state_dict`` key to the TensorRT-LLM attribute path.

    The only naming difference is the attention output projection:
    PyTorch's ``out_proj`` is called ``dense`` on the TensorRT-LLM side.
    """
    return torch_key.replace('.out_proj.', '.dense.')


def load_weights(model, ckpt, keys=WEIGHT_KEYS):
    """Copy checkpoint tensors into ``model`` parameters via ``.value``.

    Args:
        model: the TensorRT-LLM module whose parameters are filled in.
        ckpt: mapping from PyTorch key to a tensor exposing ``.numpy()``.
        keys: which checkpoint keys to copy, in assignment order.

    Raises:
        KeyError: if a key is missing from ``ckpt``.
        AttributeError: if the mapped attribute path does not exist on ``model``.
    """
    for key in keys:
        # Read the source first so a missing checkpoint key raises KeyError
        # before the target attribute path is walked.
        value = ckpt[key].numpy()
        target = model
        for attr in trtllm_param_name(key).split('.'):
            target = getattr(target, attr)
        target.value = value


if __name__ == '__main__':

    logger.set_level('info')
    torch.cuda.set_device(0)
    tensorrt_llm.logger.set_level('info')

    # create builder
    builder = Builder()
    builder_config = builder.create_builder_config(
        name='SimpleWhisper',
        precision='float32',
        timing_cache='model.cache',
        tensor_parallel=1,
        parallel_build=False,
        int8=False,
        opt_level=None,
        )

    # create tensorrt-llm model
    tensorrt_llm_test = tensorrt_llm.models.SimpleConvTRTLLMNet()

    ckpt = torch.load('weight.pth', map_location='cpu')

    print(ckpt.keys())

    # Copy every PyTorch weight into its TensorRT-LLM counterpart
    # (a table instead of 24 hand-written, typo-prone assignments).
    load_weights(tensorrt_llm_test, ckpt)

    network = builder.create_network()
    network.trt_network.name = 'SimpleWhisper'

    with net_guard(network):

        network.set_named_parameters(tensorrt_llm_test.named_parameters())

        inputs = tensorrt_llm_test.prepare_inputs()

        # NOTE(review): network outputs are presumably marked inside the
        # model's forward -- confirm in SimpleConvTRTLLMNet.
        tensorrt_llm_test(*inputs)

    engine = builder.build_engine(network, builder_config)

    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if engine is None:
        raise RuntimeError('Failed to build engine')

    serialize_engine(engine, 'simplewhisper.engine')

运行并对比结果(对应run.py):

import argparse
import csv
import json
from pathlib import Path
import contextlib
import numpy as np
import torch

import tensorrt as trt
import tensorrt_llm
from tensorrt_llm.runtime import Session, TensorInfo

from create import SimpleConvTorchNet

@contextlib.contextmanager
def _scoped_stream():
    """Yield torch's current CUDA stream as a raw handle; synchronize it on exit.

    The raw ``cuda_stream`` handle is yielded rather than the
    ``torch.cuda.Stream`` object, because TensorRT and other libraries do not
    accept torch stream objects. The stream is synchronized even if the body
    raises.
    """
    # TODO: drop the torch dependency in favour of the native CUDA Python bindings.
    import torch
    current = torch.cuda.current_stream()
    try:
        yield current.cuda_stream
    finally:
        current.synchronize()

def summarize(x):
    """Return ``(shape, min, mean, max, var)`` for a NumPy array ``x``."""
    return x.shape, x.min(), x.mean(), x.max(), x.var()


def print_comparison(a, b):
    """Print summary statistics of ``a``, ``b`` and ``|a - b|``, one line each.

    Each line is ``shape min mean max var``, space separated.
    """
    for arr in (a, b, np.abs(a - b)):
        print(*summarize(arr))


if __name__ == '__main__':

    tensorrt_llm.logger.set_level('info')
    runtime_rank = tensorrt_llm.mpi_rank()
    runtime_mapping = tensorrt_llm.Mapping(1, 0)
    torch.cuda.set_device(0)

    # load engine
    with open('simplewhisper.engine', 'rb') as f:
        engine_buffer = f.read()
    session = Session.from_serialized_engine(engine_buffer)

    # infer output shapes from concrete input shapes
    inputs_shape = [
        TensorInfo('data', trt.float32, (1, 1, 512)),
        TensorInfo('length', trt.float32, (1,)),
        TensorInfo('encoder_hidden_states', trt.float32, (1, 1500, 512)),
        TensorInfo('self_attn_past_key_value', trt.float32, (2, 8, 23, 64)),
        TensorInfo('cross_attn_past_key_value', trt.float32, (2, 8, 1500, 64)),
    ]
    outputs_shape = session.infer_shapes(inputs_shape)

    # allocate device buffers
    inputs = {
        'data': torch.rand(1, 1, 512).cuda(),
        'length': torch.Tensor([1.0]).cuda(),
        'encoder_hidden_states': torch.rand(1, 1500, 512).cuda(),
        'self_attn_past_key_value': torch.rand(2, 8, 23, 64).cuda(),
        'cross_attn_past_key_value': torch.rand(2, 8, 1500, 64).cuda(),
    }
    # NOTE(review): output buffers use torch's default float32 dtype; this
    # matches an engine built with precision='float32' -- revisit if that changes.
    outputs = {o.name: torch.zeros(*o.shape).cuda() for o in outputs_shape}

    # execute
    with _scoped_stream() as stream:
        ok = session.run(inputs, outputs, stream)
    torch.cuda.synchronize()
    # Fail loudly: otherwise a failed run would silently compare the
    # zero-filled output buffers below. (``ok`` is presumably a success flag.)
    if not ok:
        raise RuntimeError('session.run failed')
    trtllm_out = outputs['output0']
    trtllm_skv = outputs['output1']
    trtllm_ckv = outputs['output2']

    torch_net = SimpleConvTorchNet()
    torch_net.load_state_dict(torch.load('weight.pth', map_location='cpu'))
    torch_net.cuda()
    # Put the reference in inference mode so any dropout layers are disabled
    # and cannot perturb the comparison.
    torch_net.eval()
    with torch.inference_mode():
        torch_out, (torch_sk, torch_sv, torch_ck, torch_cv) = torch_net(
            inputs['data'], inputs['encoder_hidden_states'],
            (inputs['self_attn_past_key_value'][0:1], inputs['self_attn_past_key_value'][1:2],
             inputs['cross_attn_past_key_value'][0:1], inputs['cross_attn_past_key_value'][1:2]))
    torch_skv = torch.cat([torch_sk, torch_sv], dim=0)
    torch_ckv = torch.cat([torch_ck, torch_cv], dim=0)

    # Compare the self-attention cache: index 0 holds keys and index 1 holds
    # values, matching how torch_skv is concatenated above.
    for idx in (0, 1):
        print_comparison(trtllm_skv[idx].cpu().numpy(), torch_skv[idx].cpu().numpy())

通过注释/取消注释模型代码中第299行的mark_output,即可对比两种情况下的结果。由于依赖较多,建议直接参考开发commit中的代码。

依次运行create.py -> build.py -> run.py即可得到输出。