axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2.04k stars 325 forks source link

ADD Florence-2 #1541

Closed itsmeterada closed 1 month ago

itsmeterada commented 2 months ago

https://huggingface.co/microsoft/Florence-2-base https://huggingface.co/microsoft/Florence-2-base-ft https://huggingface.co/microsoft/Florence-2-large https://huggingface.co/microsoft/Florence-2-large-ft

kyakuno commented 2 months ago

@ooe1123 GPT-SoVITS2の後に、こちらをお願いできると嬉しいです。

kyakuno commented 2 months ago

軽量なVLMモデルです。

kyakuno commented 2 months ago

MITライセンス。

kyakuno commented 2 months ago

Finetuning

公式(Azure) : https://qiita.com/yo-naka/items/ba552a54b856f9318a7c HuggingFace : https://huggingface.co/blog/finetune-florence2 サードパーティ:https://github.com/andimarafioti/florence2-finetuning

kyakuno commented 2 months ago

baseのfloatで460MBぐらいとのこと。

kyakuno commented 2 months ago

tokenizerはBartTokenizerを使っている。

kyakuno commented 2 months ago

Roberta Tokenizerに近い。 https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/tokenization_bart.py

ooe1123 commented 1 month ago

embeddings.onnx

〇 huggingface/modules/transformers_modules/microsoft/Florence-2-large/xxx/modeling_florence2.py

# Original (pre-patch) excerpt of generate() from modeling_florence2.py,
# shown for reference; the body is elided with `...`.
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    def generate(
        self,
        input_ids, 
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
        ):

        # When no precomputed embeddings are supplied, the elided original
        # code presumably derives them from input_ids / pixel_values —
        # not visible in this excerpt.
        if inputs_embeds is None:
            ...

class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    def generate(
        self,
        input_ids,
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
        ):
        """Patched generate() that exports the token-embedding submodule to
        'embeddings.onnx' and then terminates the process.

        The export block runs unconditionally (the guard is always true);
        `input_ids` serves as the sample input for tracing. The exported
        graph maps input_ids -> inputs_embeds with dynamic batch and
        sequence axes.
        """
        # EXP: one-shot ONNX export, then exit.
        if 1:
            print("------>")
            # torch.autograd.Variable has been deprecated since PyTorch 0.4;
            # tensors are passed to torch.onnx.export directly. no_grad()
            # avoids building an autograd graph while tracing (matches the
            # feature.onnx export patch).
            with torch.no_grad():
                torch.onnx.export(
                    self.get_input_embeddings(), input_ids, 'embeddings.onnx',
                    input_names=["input_ids"],
                    output_names=["inputs_embeds"],
                    dynamic_axes={
                        "input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
                        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"}
                    },
                    verbose=False, opset_version=17
                )
            print("<------")
            exit()

        if inputs_embeds is None:
            ...

feature.onnx

〇 huggingface/modules/transformers_modules/microsoft/Florence-2-large/xxx/modeling_florence2.py

# Original (pre-patch) excerpt of generate() from modeling_florence2.py,
# repeated here for reference; the body is elided with `...`.
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    def generate(
        self,
        input_ids, 
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
        ):

        # When no precomputed embeddings are supplied, the elided original
        # code presumably derives them from input_ids / pixel_values —
        # not visible in this excerpt.
        if inputs_embeds is None:
            ...

class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    def generate(
        self,
        input_ids,
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
        ):
        """Patched generate() that exports the image encoder to
        'feature.onnx' and then terminates the process.

        A thin wrapper module exposes `_encode_image` as `forward` so that
        torch.onnx.export can trace just the vision tower. `pixel_values`
        serves as the sample input; only the batch axis is dynamic.
        """
        # EXP: one-shot ONNX export, then exit.
        if 1:
            class Exp(torch.nn.Module):
                """Wrapper exposing only the image-encoding path."""

                def __init__(self, model):
                    super().__init__()
                    self.model = model

                def forward(self, pixel_values):
                    # Delegates to the private image encoder of the parent
                    # Florence-2 model.
                    return self.model._encode_image(pixel_values)

            with torch.no_grad():
                print("------>")
                model = Exp(self)
                # torch.autograd.Variable is deprecated since PyTorch 0.4;
                # the tensor is passed to torch.onnx.export directly.
                torch.onnx.export(
                    model, pixel_values, 'feature.onnx',
                    input_names=["pixel_values"],
                    output_names=["image_features"],
                    dynamic_axes={
                        "pixel_values": {0: "batch_size"},
                        "image_features": {0: "batch_size"},
                    },
                    verbose=False, opset_version=17
                )
                print("<------")
            exit()

        if inputs_embeds is None:
            ...
ooe1123 commented 1 month ago

encoder_model.onnx

〇 /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py

# Original (pre-patch) excerpt of transformers/generation/utils.py.
# NOTE: this is pseudo-code — `...` stands in for elided parameters and
# statements, so the snippet is not runnable as-is.
class GenerationMixin:
    ...
    def _prepare_encoder_decoder_kwargs_for_generation(
        self,
        ...
    ) -> Dict[str, Any]:
        ...
        # Runs the encoder once and stores its output in model_kwargs so the
        # decoding loop can reuse it.
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

# Patched excerpt of transformers/generation/utils.py that exports the text
# encoder to 'encoder_model.onnx' immediately after the real encoder call,
# then terminates the process.
# NOTE: pseudo-code — `...` stands in for elided parameters and statements.
class GenerationMixin:
    ...
    def _prepare_encoder_decoder_kwargs_for_generation(
        self,
        ...
    ) -> Dict[str, Any]:
        ...
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        # EXP: one-shot ONNX export of the encoder, then exit.
        if 1:
            class Exp(nn.Module):
                # Wrapper that fixes the boolean flags and exposes only
                # (inputs_embeds, attention_mask) -> last_hidden_state,
                # giving the exported graph a plain tensor interface.
                def __init__(self, encoder):
                    super(Exp, self).__init__()
                    self.encoder = encoder
                def forward(
                        self, inputs_embeds, attention_mask
                    ):
                    encoder_kwargs = {
                        "inputs_embeds": inputs_embeds,
                        "attention_mask": attention_mask,
                        "output_attentions": False,
                        "output_hidden_states": False,
                        "return_dict": True,
                    }
                    outputs = self.encoder(
                        **encoder_kwargs,
                    )
                    return outputs.last_hidden_state

            model = Exp(encoder)
            print("------>")
            # NOTE(review): torch.autograd.Variable is deprecated since
            # PyTorch 0.4 — the tensors could be passed directly.
            from torch.autograd import Variable
            xx = (
                Variable(encoder_kwargs["inputs_embeds"]),
                Variable(encoder_kwargs["attention_mask"]),
            )
            torch.onnx.export(
                model, xx, 'encoder_model.onnx',
                input_names=[
                    'inputs_embeds', 'attention_mask', 
                ],
                output_names=[
                    'last_hidden_state',
                ],
                # The axis label 'encoder_sequence_length / 2' is only a
                # name for the dynamic dimension; ONNX does not evaluate it.
                dynamic_axes={
                    'inputs_embeds': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                    'attention_mask': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                    'last_hidden_state': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                },
                verbose=False, opset_version=17
            )
            print("<------")
            # Stop generation entirely; this patch exists only to produce
            # the ONNX file.
            exit(0)

        return model_kwargs
ooe1123 commented 1 month ago

decoder_model.onnx

〇 huggingface/modules/transformers_modules/microsoft/Florence-2-base/xxx/modeling_florence2.py

# Original (pre-patch) excerpt of the SDPA attention from
# modeling_florence2.py, showing the four key/value construction branches.
# NOTE: pseudo-code — `...` stands in for elided parameters and statements.
class Florence2SdpaAttention(Florence2Attention):
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        # Branch 1: cross-attention with a cache whose sequence length
        # matches the encoder output — the cached k/v can be reused as-is.
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        # Branch 2: cross-attention without a usable cache — project the
        # encoder states fresh.
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        # Branch 3: self-attention with a cache — project the new tokens and
        # append them to the cached k/v along the sequence axis.
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        # Branch 4: self-attention without a cache (e.g. first step).
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

# Patched attention for ONNX export. `ctrl_flg` is a module-level flag
# (added by this patch under transformers.utils.export) that switches the
# k/v construction to an export-friendly form with no data-dependent
# branching on cache contents.
from transformers.utils.export import ctrl_flg

# NOTE: pseudo-code — `...` stands in for elided parameters and statements.
class Florence2SdpaAttention(Florence2Attention):
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        # Flag off: keep the stock four-branch behavior unchanged.
        if ctrl_flg[0] is False:
            if (
                is_cross_attention
                and past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]
            ):
                # reuse k,v, cross_attentions
                key_states = past_key_value[0]
                value_states = past_key_value[1]
            elif is_cross_attention:
                # cross_attentions
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            elif past_key_value is not None:
                # reuse k, v, self_attention
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            else:
                # self_attention
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        # Flag on (export mode): past_key_value is always present (the
        # _beam_search patch seeds zero-length caches), so both paths can
        # unconditionally concatenate onto the cache.
        else:
            if is_cross_attention:
                # Concatenate fresh projections after the cache, then slice
                # back to the encoder length — with an empty cache this
                # equals the fresh projections; with a full cache the slice
                # keeps the cached prefix. Trace-friendly: no shape test.
                key_states = torch.cat([past_key_value[0], self._shape(self.k_proj(key_value_states), -1, bsz)], dim=2)
                value_states = torch.cat([past_key_value[1], self._shape(self.v_proj(key_value_states), -1, bsz)], dim=2)
                key_states = key_states[:,:,:key_value_states.shape[1],:]
                value_states = value_states[:,:,:key_value_states.shape[1],:]
            else:
                # Self-attention: same as the cached branch of the original.
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

〇 /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py

# Original (pre-patch) excerpt of _beam_search from
# transformers/generation/utils.py.
# NOTE: pseudo-code — `...` stands in for elided parameters and statements.
class GenerationMixin:
    ...
    def _beam_search(
        self,
        ...
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        ...
        # Decoding loop: one forward pass of the model per step until all
        # beams are finished.
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            ...
            if sequential:
                ...
            else:  # Unchanged original behavior
                outputs = self(**model_inputs, return_dict=True)

class GenerationMixin:
    ...
    def _beam_search(
        self,
        ...
    ) -> Union[GenerateBeamOutput, torch.LongTensor]:
        ...
        # Add
        from transformers.utils.export import ctrl_flg
        ctrl_flg[0] = True

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            ...
            # Add
            if ctrl_flg[0] and model_inputs["past_key_values"] is None:
                b = model_inputs["encoder_outputs"][0].size(0)
                d = model_inputs["encoder_outputs"][0].device
                # model_inputs["past_key_values"] = [
                #     [
                #         torch.zeros(b, 12, 0, 64, dtype=torch.float16).to(d),
                #     ] * 4
                # ] * 6
                model_inputs["past_key_values"] = [
                    [
                        torch.zeros(b, 16, 0, 64, dtype=torch.float16).to(d),
                    ] * 4
                ] * 12

            # EXP
            # if 0:
            if 1 and 0 < model_inputs["past_key_values"][0][0].size(2):
                class Exp(nn.Module):
                    def __init__(self, net):
                        super(Exp, self).__init__()
                        self.net = net
                    def forward(
                            self, decoder_input_ids, encoder_hidden_states,
                            past_key_values_0_decoder_key, past_key_values_0_decoder_value, past_key_values_0_encoder_key, past_key_values_0_encoder_value, 
                            past_key_values_1_decoder_key, past_key_values_1_decoder_value, past_key_values_1_encoder_key, past_key_values_1_encoder_value,
                            past_key_values_2_decoder_key, past_key_values_2_decoder_value, past_key_values_2_encoder_key, past_key_values_2_encoder_value,
                            past_key_values_3_decoder_key, past_key_values_3_decoder_value, past_key_values_3_encoder_key, past_key_values_3_encoder_value,
                            past_key_values_4_decoder_key, past_key_values_4_decoder_value, past_key_values_4_encoder_key, past_key_values_4_encoder_value,
                            past_key_values_5_decoder_key, past_key_values_5_decoder_value, past_key_values_5_encoder_key, past_key_values_5_encoder_value,
                            ###
                            past_key_values_6_decoder_key, past_key_values_6_decoder_value, past_key_values_6_encoder_key, past_key_values_6_encoder_value, 
                            past_key_values_7_decoder_key, past_key_values_7_decoder_value, past_key_values_7_encoder_key, past_key_values_7_encoder_value,
                            past_key_values_8_decoder_key, past_key_values_8_decoder_value, past_key_values_8_encoder_key, past_key_values_8_encoder_value,
                            past_key_values_9_decoder_key, past_key_values_9_decoder_value, past_key_values_9_encoder_key, past_key_values_9_encoder_value,
                            past_key_values_10_decoder_key, past_key_values_10_decoder_value, past_key_values_10_encoder_key, past_key_values_10_encoder_value,
                            past_key_values_11_decoder_key, past_key_values_11_decoder_value, past_key_values_11_encoder_key, past_key_values_11_encoder_value,
                        ):
                        model_inputs = {
                            "decoder_input_ids": decoder_input_ids,
                            "encoder_outputs": [encoder_hidden_states],
                            "past_key_values": [
                                [
                                    past_key_values_0_decoder_key,
                                    past_key_values_0_decoder_value,
                                    past_key_values_0_encoder_key,
                                    past_key_values_0_encoder_value,
                                ],
                                [
                                    past_key_values_1_decoder_key,
                                    past_key_values_1_decoder_value,
                                    past_key_values_1_encoder_key,
                                    past_key_values_1_encoder_value,
                                ],
                                [
                                    past_key_values_2_decoder_key,
                                    past_key_values_2_decoder_value,
                                    past_key_values_2_encoder_key,
                                    past_key_values_2_encoder_value,
                                ],
                                [
                                    past_key_values_3_decoder_key,
                                    past_key_values_3_decoder_value,
                                    past_key_values_3_encoder_key,
                                    past_key_values_3_encoder_value,
                                ],
                                [
                                    past_key_values_4_decoder_key,
                                    past_key_values_4_decoder_value,
                                    past_key_values_4_encoder_key,
                                    past_key_values_4_encoder_value,
                                ],
                                [
                                    past_key_values_5_decoder_key,
                                    past_key_values_5_decoder_value,
                                    past_key_values_5_encoder_key,
                                    past_key_values_5_encoder_value,
                                ],
                                [
                                    past_key_values_6_decoder_key,
                                    past_key_values_6_decoder_value,
                                    past_key_values_6_encoder_key,
                                    past_key_values_6_encoder_value,
                                ],
                                [
                                    past_key_values_7_decoder_key,
                                    past_key_values_7_decoder_value,
                                    past_key_values_7_encoder_key,
                                    past_key_values_7_encoder_value,
                                ],
                                [
                                    past_key_values_8_decoder_key,
                                    past_key_values_8_decoder_value,
                                    past_key_values_8_encoder_key,
                                    past_key_values_8_encoder_value,
                                ],
                                [
                                    past_key_values_9_decoder_key,
                                    past_key_values_9_decoder_value,
                                    past_key_values_9_encoder_key,
                                    past_key_values_9_encoder_value,
                                ],
                                [
                                    past_key_values_10_decoder_key,
                                    past_key_values_10_decoder_value,
                                    past_key_values_10_encoder_key,
                                    past_key_values_10_encoder_value,
                                ],
                                [
                                    past_key_values_11_decoder_key,
                                    past_key_values_11_decoder_value,
                                    past_key_values_11_encoder_key,
                                    past_key_values_11_encoder_value,
                                ],
                            ],
                        }
                        outputs = self.net(
                            **model_inputs,
                            use_cache=True,
                            return_dict=True,
                        )
                        return (
                            outputs["logits"],
                            outputs["past_key_values"][0][0],
                            outputs["past_key_values"][0][1],
                            outputs["past_key_values"][0][2],
                            outputs["past_key_values"][0][3],
                            outputs["past_key_values"][1][0],
                            outputs["past_key_values"][1][1],
                            outputs["past_key_values"][1][2],
                            outputs["past_key_values"][1][3],
                            outputs["past_key_values"][2][0],
                            outputs["past_key_values"][2][1],
                            outputs["past_key_values"][2][2],
                            outputs["past_key_values"][2][3],
                            outputs["past_key_values"][3][0],
                            outputs["past_key_values"][3][1],
                            outputs["past_key_values"][3][2],
                            outputs["past_key_values"][3][3],
                            outputs["past_key_values"][4][0],
                            outputs["past_key_values"][4][1],
                            outputs["past_key_values"][4][2],
                            outputs["past_key_values"][4][3],
                            outputs["past_key_values"][5][0],
                            outputs["past_key_values"][5][1],
                            outputs["past_key_values"][5][2],
                            outputs["past_key_values"][5][3],
                            ###
                            outputs["past_key_values"][6][0],
                            outputs["past_key_values"][6][1],
                            outputs["past_key_values"][6][2],
                            outputs["past_key_values"][6][3],
                            outputs["past_key_values"][7][0],
                            outputs["past_key_values"][7][1],
                            outputs["past_key_values"][7][2],
                            outputs["past_key_values"][7][3],
                            outputs["past_key_values"][8][0],
                            outputs["past_key_values"][8][1],
                            outputs["past_key_values"][8][2],
                            outputs["past_key_values"][8][3],
                            outputs["past_key_values"][9][0],
                            outputs["past_key_values"][9][1],
                            outputs["past_key_values"][9][2],
                            outputs["past_key_values"][9][3],
                            outputs["past_key_values"][10][0],
                            outputs["past_key_values"][10][1],
                            outputs["past_key_values"][10][2],
                            outputs["past_key_values"][10][3],
                            outputs["past_key_values"][11][0],
                            outputs["past_key_values"][11][1],
                            outputs["past_key_values"][11][2],
                            outputs["past_key_values"][11][3],
                        )

                model = Exp(self)
                print("------>")
                from torch.autograd import Variable
                xx = (
                    Variable(model_inputs["decoder_input_ids"]),
                    Variable(model_inputs["encoder_outputs"].last_hidden_state),
                    Variable(model_inputs["past_key_values"][0][0]),
                    Variable(model_inputs["past_key_values"][0][1]),
                    Variable(model_inputs["past_key_values"][0][2]),
                    Variable(model_inputs["past_key_values"][0][3]),
                    Variable(model_inputs["past_key_values"][1][0]),
                    Variable(model_inputs["past_key_values"][1][1]),
                    Variable(model_inputs["past_key_values"][1][2]),
                    Variable(model_inputs["past_key_values"][1][3]),
                    Variable(model_inputs["past_key_values"][2][0]),
                    Variable(model_inputs["past_key_values"][2][1]),
                    Variable(model_inputs["past_key_values"][2][2]),
                    Variable(model_inputs["past_key_values"][2][3]),
                    Variable(model_inputs["past_key_values"][3][0]),
                    Variable(model_inputs["past_key_values"][3][1]),
                    Variable(model_inputs["past_key_values"][3][2]),
                    Variable(model_inputs["past_key_values"][3][3]),
                    Variable(model_inputs["past_key_values"][4][0]),
                    Variable(model_inputs["past_key_values"][4][1]),
                    Variable(model_inputs["past_key_values"][4][2]),
                    Variable(model_inputs["past_key_values"][4][3]),
                    Variable(model_inputs["past_key_values"][5][0]),
                    Variable(model_inputs["past_key_values"][5][1]),
                    Variable(model_inputs["past_key_values"][5][2]),
                    Variable(model_inputs["past_key_values"][5][3]),
                    ### 
                    Variable(model_inputs["past_key_values"][6][0]),
                    Variable(model_inputs["past_key_values"][6][1]),
                    Variable(model_inputs["past_key_values"][6][2]),
                    Variable(model_inputs["past_key_values"][6][3]),
                    Variable(model_inputs["past_key_values"][7][0]),
                    Variable(model_inputs["past_key_values"][7][1]),
                    Variable(model_inputs["past_key_values"][7][2]),
                    Variable(model_inputs["past_key_values"][7][3]),
                    Variable(model_inputs["past_key_values"][8][0]),
                    Variable(model_inputs["past_key_values"][8][1]),
                    Variable(model_inputs["past_key_values"][8][2]),
                    Variable(model_inputs["past_key_values"][8][3]),
                    Variable(model_inputs["past_key_values"][9][0]),
                    Variable(model_inputs["past_key_values"][9][1]),
                    Variable(model_inputs["past_key_values"][9][2]),
                    Variable(model_inputs["past_key_values"][9][3]),
                    Variable(model_inputs["past_key_values"][10][0]),
                    Variable(model_inputs["past_key_values"][10][1]),
                    Variable(model_inputs["past_key_values"][10][2]),
                    Variable(model_inputs["past_key_values"][10][3]),
                    Variable(model_inputs["past_key_values"][11][0]),
                    Variable(model_inputs["past_key_values"][11][1]),
                    Variable(model_inputs["past_key_values"][11][2]),
                    Variable(model_inputs["past_key_values"][11][3]),
                )
                torch.onnx.export(
                    model, xx, 'decoder_model.onnx',
                    input_names=[
                        'input_ids', 'encoder_hidden_states', 
                        'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value', 'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value', 
                        'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value', 'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value', 
                        'past_key_values.2.decoder.key', 'past_key_values.2.decoder.value', 'past_key_values.2.encoder.key', 'past_key_values.2.encoder.value', 
                        'past_key_values.3.decoder.key', 'past_key_values.3.decoder.value', 'past_key_values.3.encoder.key', 'past_key_values.3.encoder.value', 
                        'past_key_values.4.decoder.key', 'past_key_values.4.decoder.value', 'past_key_values.4.encoder.key', 'past_key_values.4.encoder.value', 
                        'past_key_values.5.decoder.key', 'past_key_values.5.decoder.value', 'past_key_values.5.encoder.key', 'past_key_values.5.encoder.value', 
                        ###
                        'past_key_values.6.decoder.key', 'past_key_values.6.decoder.value', 'past_key_values.6.encoder.key', 'past_key_values.6.encoder.value', 
                        'past_key_values.7.decoder.key', 'past_key_values.7.decoder.value', 'past_key_values.7.encoder.key', 'past_key_values.7.encoder.value', 
                        'past_key_values.8.decoder.key', 'past_key_values.8.decoder.value', 'past_key_values.8.encoder.key', 'past_key_values.8.encoder.value', 
                        'past_key_values.9.decoder.key', 'past_key_values.9.decoder.value', 'past_key_values.9.encoder.key', 'past_key_values.9.encoder.value', 
                        'past_key_values.10.decoder.key', 'past_key_values.10.decoder.value', 'past_key_values.10.encoder.key', 'past_key_values.10.encoder.value', 
                        'past_key_values.11.decoder.key', 'past_key_values.11.decoder.value', 'past_key_values.11.encoder.key', 'past_key_values.11.encoder.value', 
                    ],
                    output_names=[
                        'logits',
                        'present.0.decoder.key', 'present.0.decoder.value', 'present.0.encoder.key', 'present.0.encoder.value', 
                        'present.1.decoder.key', 'present.1.decoder.value', 'present.1.encoder.key', 'present.1.encoder.value',
                        'present.2.decoder.key', 'present.2.decoder.value', 'present.2.encoder.key', 'present.2.encoder.value',
                        'present.3.decoder.key', 'present.3.decoder.value', 'present.3.encoder.key', 'present.3.encoder.value',
                        'present.4.decoder.key', 'present.4.decoder.value', 'present.4.encoder.key', 'present.4.encoder.value',
                        'present.5.decoder.key', 'present.5.decoder.value', 'present.5.encoder.key', 'present.5.encoder.value',
                        ###
                        'present.6.decoder.key', 'present.6.decoder.value', 'present.6.encoder.key', 'present.6.encoder.value', 
                        'present.7.decoder.key', 'present.7.decoder.value', 'present.7.encoder.key', 'present.7.encoder.value',
                        'present.8.decoder.key', 'present.8.decoder.value', 'present.8.encoder.key', 'present.8.encoder.value',
                        'present.9.decoder.key', 'present.9.decoder.value', 'present.9.encoder.key', 'present.9.encoder.value',
                        'present.10.decoder.key', 'present.10.decoder.value', 'present.10.encoder.key', 'present.10.encoder.value',
                        'present.11.decoder.key', 'present.11.decoder.value', 'present.11.encoder.key', 'present.11.encoder.value',
                    ],
                    dynamic_axes={
                        'input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'encoder_hidden_states': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                        'logits': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'past_key_values.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        # (continued) dynamic_axes map for the decoder-with-past
                        # torch.onnx.export call: each transformer layer owns four
                        # cache tensors — decoder self-attention key/value plus
                        # encoder cross-attention key/value — with batch (axis 0)
                        # and sequence length (axis 2) marked dynamic.
                        # NOTE(review): the *.encoder.* past entries below reuse the
                        # symbolic name 'past_decoder_sequence_length' for axis 2,
                        # while the matching outputs use 'encoder_sequence_length_out'.
                        # The names are only symbolic labels, so the export itself is
                        # unaffected, but the cross-attention cache length is the
                        # encoder sequence length — confirm whether the label was
                        # meant to differ.
                        'past_key_values.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.2.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.2.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.2.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.2.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.3.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.3.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.3.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.3.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.4.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.4.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.4.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.4.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.5.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.5.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.5.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.5.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        # layers 6-11: past (input) caches
                        'past_key_values.6.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.6.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.6.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.6.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.7.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.7.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.7.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.7.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.8.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.8.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.8.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.8.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.9.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.9.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.9.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.9.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.10.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.10.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.10.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.10.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.11.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.11.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.11.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.11.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        # outputs: updated caches. Decoder entries grow by one token
                        # per step ('+ 1'); encoder entries stay at the encoder
                        # sequence length.
                        'present.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.0.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.2.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.2.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.2.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.2.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.3.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.3.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.3.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.3.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.4.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.4.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.4.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.4.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.5.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.5.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.5.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.5.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        # layers 6-11: present (output) caches
                        'present.6.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.6.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.6.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.6.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.7.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.7.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.7.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.7.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.8.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.8.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.8.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.8.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.9.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.9.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.9.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.9.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.10.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.10.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.10.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.10.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.11.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.11.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.11.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.11.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                    },
                    verbose=False, opset_version=17
                )
                # One-shot export harness: report completion, then terminate the
                # process so generation never runs past the ONNX dump.
                print("<------")
                exit(0)

            # Normal decoding path (unreachable while the export harness above
            # calls exit(0)).
            if sequential:
                # sequential path — body elided in this excerpt
                ...
            else:  # Unchanged original behavior
                outputs = self(**model_inputs, return_dict=True)