axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2.04k stars 325 forks source link

ADD llava #1411

Closed kyakuno closed 4 months ago

kyakuno commented 8 months ago

https://github.com/haotian-liu/LLaVA apache

ooe1123 commented 6 months ago

llava-v1.5-7b.onnxエクスポート

〇 transformers/models/llama/modeling_llama.py

class LlamaModel(LlamaPreTrainedModel):
    ...
    def forward(
        ...
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

class LlamaModel(LlamaPreTrainedModel):
    ...
    def forward(
        ...
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...
        inputs_embeds = torch.cat([inputs_embeds, self.embed_tokens(input_ids)], dim=1)
        batch_size, seq_length = inputs_embeds.shape[:2]
class LlamaSdpaAttention(LlamaAttention):
    ...
    # Adapted from LlamaAttention.forward
    def forward(
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

class LlamaSdpaAttention(LlamaAttention):
    ...
    # Adapted from LlamaAttention.forward
    def forward(
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        Q = query_states
        K = key_states
        V = value_states
        L, S = Q.size(-2), K.size(-2)
        attn_bias = torch.zeros(L, S, dtype=Q.dtype)
        if torch.onnx.is_in_onnx_export():
            def tril(L, S):
                arange = torch.arange(S)
                mask = arange.expand(S, S)
                arange = arange.unsqueeze(-1)
                mask = torch.le(mask, arange)[:L]
                return mask

            is_causal = torch.gt(q_len, 1).type(torch.int64)
            sel = torch.stack([
                torch.ones(L, S, dtype=torch.bool),
                tril(L, S)
            ])
            mask = sel[is_causal]
        else:
            is_causal=q_len > 1
            if is_causal:
                mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            else:
                mask = torch.ones(L, S, dtype=torch.bool)
        attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(Q.device)
        if torch.onnx.is_in_onnx_export():
            attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / torch.sqrt(Q.size(-1))) + attn_bias, dim=-1)
        else:
            attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_bias, dim=-1)
        attn_output = attn_weight @ V

〇 transformers/generation/utils.py

class GenerationMixin:
    ...
    def greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...
        while True:
            ...
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

class GenerationMixin:
    ...
    def greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...
        while True:
            ...
            b = model_inputs["position_ids"].size(0)
            d = model_inputs["position_ids"].device
            if "input_ids" not in model_inputs:
                model_inputs["input_ids"] = torch.zeros(b, 0, dtype=torch.int64).to(d)
            if "inputs_embeds" not in model_inputs:
                model_inputs["inputs_embeds"] = torch.zeros(b, 0, 4096, dtype=torch.float16).to(d)
            if model_inputs["past_key_values"] is None:
                model_inputs["past_key_values"] = [
                    [
                        torch.zeros(b, 32, 0, 128, dtype=torch.float16).to(d)
                    ] * 2
                ] * 32

            if 0 < model_inputs["past_key_values"][0][0].size(2):
                class Net(nn.Module):
                    def __init__(self, net):
                        super(Net, self).__init__()
                        self.net = net

                    def forward(
                            self, 
                            input_ids, inputs_embeds, 
                            position_ids, attention_mask,
                            past_key_values_0_key, past_key_values_0_value, 
                            past_key_values_1_key, past_key_values_1_value, 
                            past_key_values_2_key, past_key_values_2_value, 
                            past_key_values_3_key, past_key_values_3_value, 
                            past_key_values_4_key, past_key_values_4_value, 
                            past_key_values_5_key, past_key_values_5_value, 
                            past_key_values_6_key, past_key_values_6_value, 
                            past_key_values_7_key, past_key_values_7_value, 
                            past_key_values_8_key, past_key_values_8_value, 
                            past_key_values_9_key, past_key_values_9_value, 
                            past_key_values_10_key, past_key_values_10_value, 
                            past_key_values_11_key, past_key_values_11_value, 
                            past_key_values_12_key, past_key_values_12_value, 
                            past_key_values_13_key, past_key_values_13_value, 
                            past_key_values_14_key, past_key_values_14_value, 
                            past_key_values_15_key, past_key_values_15_value, 
                            past_key_values_16_key, past_key_values_16_value, 
                            past_key_values_17_key, past_key_values_17_value, 
                            past_key_values_18_key, past_key_values_18_value, 
                            past_key_values_19_key, past_key_values_19_value, 
                            past_key_values_20_key, past_key_values_20_value, 
                            past_key_values_21_key, past_key_values_21_value, 
                            past_key_values_22_key, past_key_values_22_value, 
                            past_key_values_23_key, past_key_values_23_value, 
                            past_key_values_24_key, past_key_values_24_value, 
                            past_key_values_25_key, past_key_values_25_value, 
                            past_key_values_26_key, past_key_values_26_value, 
                            past_key_values_27_key, past_key_values_27_value, 
                            past_key_values_28_key, past_key_values_28_value, 
                            past_key_values_29_key, past_key_values_29_value, 
                            past_key_values_30_key, past_key_values_30_value, 
                            past_key_values_31_key, past_key_values_31_value, 
                        ):
                        model_inputs = {
                            "input_ids": input_ids,
                            "inputs_embeds": inputs_embeds,
                            "position_ids": position_ids,
                            "attention_mask": attention_mask,
                            "past_key_values": [
                                [ past_key_values_0_key, past_key_values_0_value ],
                                [ past_key_values_1_key, past_key_values_1_value ],
                                [ past_key_values_2_key, past_key_values_2_value ],
                                [ past_key_values_3_key, past_key_values_3_value ],
                                [ past_key_values_4_key, past_key_values_4_value ],
                                [ past_key_values_5_key, past_key_values_5_value ],
                                [ past_key_values_6_key, past_key_values_6_value ],
                                [ past_key_values_7_key, past_key_values_7_value ],
                                [ past_key_values_8_key, past_key_values_8_value ],
                                [ past_key_values_9_key, past_key_values_9_value ],
                                [ past_key_values_10_key, past_key_values_10_value ],
                                [ past_key_values_11_key, past_key_values_11_value ],
                                [ past_key_values_12_key, past_key_values_12_value ],
                                [ past_key_values_13_key, past_key_values_13_value ],
                                [ past_key_values_14_key, past_key_values_14_value ],
                                [ past_key_values_15_key, past_key_values_15_value ],
                                [ past_key_values_16_key, past_key_values_16_value ],
                                [ past_key_values_17_key, past_key_values_17_value ],
                                [ past_key_values_18_key, past_key_values_18_value ],
                                [ past_key_values_19_key, past_key_values_19_value ],
                                [ past_key_values_20_key, past_key_values_20_value ],
                                [ past_key_values_21_key, past_key_values_21_value ],
                                [ past_key_values_22_key, past_key_values_22_value ],
                                [ past_key_values_23_key, past_key_values_23_value ],
                                [ past_key_values_24_key, past_key_values_24_value ],
                                [ past_key_values_25_key, past_key_values_25_value ],
                                [ past_key_values_26_key, past_key_values_26_value ],
                                [ past_key_values_27_key, past_key_values_27_value ],
                                [ past_key_values_28_key, past_key_values_28_value ],
                                [ past_key_values_29_key, past_key_values_29_value ],
                                [ past_key_values_30_key, past_key_values_30_value ],
                                [ past_key_values_31_key, past_key_values_31_value ],
                            ],
                            "use_cache": True,
                        }
                        outputs = self.net(
                            **model_inputs,
                            return_dict=True,
                            output_attentions=False,
                            output_hidden_states=False,
                        )
                        return (
                            outputs["logits"],
                            outputs["past_key_values"][0][0],
                            outputs["past_key_values"][0][1],
                            outputs["past_key_values"][1][0],
                            outputs["past_key_values"][1][1],
                            outputs["past_key_values"][2][0],
                            outputs["past_key_values"][2][1],
                            outputs["past_key_values"][3][0],
                            outputs["past_key_values"][3][1],
                            outputs["past_key_values"][4][0],
                            outputs["past_key_values"][4][1],
                            outputs["past_key_values"][5][0],
                            outputs["past_key_values"][5][1],
                            outputs["past_key_values"][6][0],
                            outputs["past_key_values"][6][1],
                            outputs["past_key_values"][7][0],
                            outputs["past_key_values"][7][1],
                            outputs["past_key_values"][8][0],
                            outputs["past_key_values"][8][1],
                            outputs["past_key_values"][9][0],
                            outputs["past_key_values"][9][1],
                            outputs["past_key_values"][10][0],
                            outputs["past_key_values"][10][1],
                            outputs["past_key_values"][11][0],
                            outputs["past_key_values"][11][1],
                            outputs["past_key_values"][12][0],
                            outputs["past_key_values"][12][1],
                            outputs["past_key_values"][13][0],
                            outputs["past_key_values"][13][1],
                            outputs["past_key_values"][14][0],
                            outputs["past_key_values"][14][1],
                            outputs["past_key_values"][15][0],
                            outputs["past_key_values"][15][1],
                            outputs["past_key_values"][16][0],
                            outputs["past_key_values"][16][1],
                            outputs["past_key_values"][17][0],
                            outputs["past_key_values"][17][1],
                            outputs["past_key_values"][18][0],
                            outputs["past_key_values"][18][1],
                            outputs["past_key_values"][19][0],
                            outputs["past_key_values"][19][1],
                            outputs["past_key_values"][20][0],
                            outputs["past_key_values"][20][1],
                            outputs["past_key_values"][21][0],
                            outputs["past_key_values"][21][1],
                            outputs["past_key_values"][22][0],
                            outputs["past_key_values"][22][1],
                            outputs["past_key_values"][23][0],
                            outputs["past_key_values"][23][1],
                            outputs["past_key_values"][24][0],
                            outputs["past_key_values"][24][1],
                            outputs["past_key_values"][25][0],
                            outputs["past_key_values"][25][1],
                            outputs["past_key_values"][26][0],
                            outputs["past_key_values"][26][1],
                            outputs["past_key_values"][27][0],
                            outputs["past_key_values"][27][1],
                            outputs["past_key_values"][28][0],
                            outputs["past_key_values"][28][1],
                            outputs["past_key_values"][29][0],
                            outputs["past_key_values"][29][1],
                            outputs["past_key_values"][30][0],
                            outputs["past_key_values"][30][1],
                            outputs["past_key_values"][31][0],
                            outputs["past_key_values"][31][1],
                        )

                model = Net(self)
                from torch.autograd import Variable
                xx = (
                    Variable(model_inputs["input_ids"]),
                    Variable(model_inputs["inputs_embeds"]),
                    Variable(model_inputs["position_ids"]),
                    Variable(model_inputs["attention_mask"]),
                    Variable(model_inputs["past_key_values"][0][0]),
                    Variable(model_inputs["past_key_values"][0][1]),
                    Variable(model_inputs["past_key_values"][1][0]),
                    Variable(model_inputs["past_key_values"][1][1]),
                    Variable(model_inputs["past_key_values"][2][0]),
                    Variable(model_inputs["past_key_values"][2][1]),
                    Variable(model_inputs["past_key_values"][3][0]),
                    Variable(model_inputs["past_key_values"][3][1]),
                    Variable(model_inputs["past_key_values"][4][0]),
                    Variable(model_inputs["past_key_values"][4][1]),
                    Variable(model_inputs["past_key_values"][5][0]),
                    Variable(model_inputs["past_key_values"][5][1]),
                    Variable(model_inputs["past_key_values"][6][0]),
                    Variable(model_inputs["past_key_values"][6][1]),
                    Variable(model_inputs["past_key_values"][7][0]),
                    Variable(model_inputs["past_key_values"][7][1]),
                    Variable(model_inputs["past_key_values"][8][0]),
                    Variable(model_inputs["past_key_values"][8][1]),
                    Variable(model_inputs["past_key_values"][9][0]),
                    Variable(model_inputs["past_key_values"][9][1]),
                    Variable(model_inputs["past_key_values"][10][0]),
                    Variable(model_inputs["past_key_values"][10][1]),
                    Variable(model_inputs["past_key_values"][11][0]),
                    Variable(model_inputs["past_key_values"][11][1]),
                    Variable(model_inputs["past_key_values"][12][0]),
                    Variable(model_inputs["past_key_values"][12][1]),
                    Variable(model_inputs["past_key_values"][13][0]),
                    Variable(model_inputs["past_key_values"][13][1]),
                    Variable(model_inputs["past_key_values"][14][0]),
                    Variable(model_inputs["past_key_values"][14][1]),
                    Variable(model_inputs["past_key_values"][15][0]),
                    Variable(model_inputs["past_key_values"][15][1]),
                    Variable(model_inputs["past_key_values"][16][0]),
                    Variable(model_inputs["past_key_values"][16][1]),
                    Variable(model_inputs["past_key_values"][17][0]),
                    Variable(model_inputs["past_key_values"][17][1]),
                    Variable(model_inputs["past_key_values"][18][0]),
                    Variable(model_inputs["past_key_values"][18][1]),
                    Variable(model_inputs["past_key_values"][19][0]),
                    Variable(model_inputs["past_key_values"][19][1]),
                    Variable(model_inputs["past_key_values"][20][0]),
                    Variable(model_inputs["past_key_values"][20][1]),
                    Variable(model_inputs["past_key_values"][21][0]),
                    Variable(model_inputs["past_key_values"][21][1]),
                    Variable(model_inputs["past_key_values"][22][0]),
                    Variable(model_inputs["past_key_values"][22][1]),
                    Variable(model_inputs["past_key_values"][23][0]),
                    Variable(model_inputs["past_key_values"][23][1]),
                    Variable(model_inputs["past_key_values"][24][0]),
                    Variable(model_inputs["past_key_values"][24][1]),
                    Variable(model_inputs["past_key_values"][25][0]),
                    Variable(model_inputs["past_key_values"][25][1]),
                    Variable(model_inputs["past_key_values"][26][0]),
                    Variable(model_inputs["past_key_values"][26][1]),
                    Variable(model_inputs["past_key_values"][27][0]),
                    Variable(model_inputs["past_key_values"][27][1]),
                    Variable(model_inputs["past_key_values"][28][0]),
                    Variable(model_inputs["past_key_values"][28][1]),
                    Variable(model_inputs["past_key_values"][29][0]),
                    Variable(model_inputs["past_key_values"][29][1]),
                    Variable(model_inputs["past_key_values"][30][0]),
                    Variable(model_inputs["past_key_values"][30][1]),
                    Variable(model_inputs["past_key_values"][31][0]),
                    Variable(model_inputs["past_key_values"][31][1]),
                )
                print("------>")
                torch.onnx.export(
                    model, xx, 'onnx/llava-v1.5-7b.onnx',
                    input_names=[
                        'input_ids', 'inputs_embeds', 
                        'position_ids', 'attention_mask', 
                        'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value',
                        'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value',
                        'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value',
                        'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value',
                        'past_key_values.2.decoder.key', 'past_key_values.2.decoder.value',
                        'past_key_values.2.encoder.key', 'past_key_values.2.encoder.value',
                        'past_key_values.3.decoder.key', 'past_key_values.3.decoder.value',
                        'past_key_values.3.encoder.key', 'past_key_values.3.encoder.value',
                        'past_key_values.4.decoder.key', 'past_key_values.4.decoder.value',
                        'past_key_values.4.encoder.key', 'past_key_values.4.encoder.value',
                        'past_key_values.5.decoder.key', 'past_key_values.5.decoder.value',
                        'past_key_values.5.encoder.key', 'past_key_values.5.encoder.value',
                        'past_key_values.6.decoder.key', 'past_key_values.6.decoder.value',
                        'past_key_values.6.encoder.key', 'past_key_values.6.encoder.value',
                        'past_key_values.7.decoder.key', 'past_key_values.7.decoder.value',
                        'past_key_values.7.encoder.key', 'past_key_values.7.encoder.value',
                        'past_key_values.8.decoder.key', 'past_key_values.8.decoder.value',
                        'past_key_values.8.encoder.key', 'past_key_values.8.encoder.value',
                        'past_key_values.9.decoder.key', 'past_key_values.9.decoder.value',
                        'past_key_values.9.encoder.key', 'past_key_values.9.encoder.value',
                        'past_key_values.10.decoder.key', 'past_key_values.10.decoder.value',
                        'past_key_values.10.encoder.key', 'past_key_values.10.encoder.value',
                        'past_key_values.11.decoder.key', 'past_key_values.11.decoder.value',
                        'past_key_values.11.encoder.key', 'past_key_values.11.encoder.value',
                        'past_key_values.12.decoder.key', 'past_key_values.12.decoder.value',
                        'past_key_values.12.encoder.key', 'past_key_values.12.encoder.value',
                        'past_key_values.13.decoder.key', 'past_key_values.13.decoder.value',
                        'past_key_values.13.encoder.key', 'past_key_values.13.encoder.value',
                        'past_key_values.14.decoder.key', 'past_key_values.14.decoder.value',
                        'past_key_values.14.encoder.key', 'past_key_values.14.encoder.value',
                        'past_key_values.15.decoder.key', 'past_key_values.15.decoder.value',
                        'past_key_values.15.encoder.key', 'past_key_values.15.encoder.value',
                    ],
                    output_names=[
                        'logits',
                        'present.0.decoder.key', 'present.0.decoder.value',
                        'present.0.encoder.key', 'present.0.encoder.value',
                        'present.1.decoder.key', 'present.1.decoder.value',
                        'present.1.encoder.key', 'present.1.encoder.value',
                        'present.2.decoder.key', 'present.2.decoder.value',
                        'present.2.encoder.key', 'present.2.encoder.value',
                        'present.3.decoder.key', 'present.3.decoder.value',
                        'present.3.encoder.key', 'present.3.encoder.value',
                        'present.4.decoder.key', 'present.4.decoder.value',
                        'present.4.encoder.key', 'present.4.encoder.value',
                        'present.5.decoder.key', 'present.5.decoder.value',
                        'present.5.encoder.key', 'present.5.encoder.value',
                        'present.6.decoder.key', 'present.6.decoder.value',
                        'present.6.encoder.key', 'present.6.encoder.value',
                        'present.7.decoder.key', 'present.7.decoder.value',
                        'present.7.encoder.key', 'present.7.encoder.value',
                        'present.8.decoder.key', 'present.8.decoder.value',
                        'present.8.encoder.key', 'present.8.encoder.value',
                        'present.9.decoder.key', 'present.9.decoder.value',
                        'present.9.encoder.key', 'present.9.encoder.value',
                        'present.10.decoder.key', 'present.10.decoder.value',
                        'present.10.encoder.key', 'present.10.encoder.value',
                        'present.11.decoder.key', 'present.11.decoder.value',
                        'present.11.encoder.key', 'present.11.encoder.value',
                        'present.12.decoder.key', 'present.12.decoder.value',
                        'present.12.encoder.key', 'present.12.encoder.value',
                        'present.13.decoder.key', 'present.13.decoder.value',
                        'present.13.encoder.key', 'present.13.encoder.value',
                        'present.14.decoder.key', 'present.14.decoder.value',
                        'present.14.encoder.key', 'present.14.encoder.value',
                        'present.15.decoder.key', 'present.15.decoder.value',
                        'present.15.encoder.key', 'present.15.encoder.value',
                    ],
                    dynamic_axes={
                        'input_ids': [0, 1],
                        'inputs_embeds': [0, 1],
                        'logits': [0, 1],
                        'position_ids': [0, 1],
                        'attention_mask': [0, 1],
                        'past_key_values.0.decoder.key': [0, 2],
                        'past_key_values.0.decoder.value': [0, 2],
                        'past_key_values.0.encoder.key': [0, 2],
                        'past_key_values.0.encoder.value': [0, 2],
                        'past_key_values.1.decoder.key': [0, 2],
                        'past_key_values.1.decoder.value': [0, 2],
                        'past_key_values.1.encoder.key': [0, 2],
                        'past_key_values.1.encoder.value': [0, 2],
                        'past_key_values.2.decoder.key': [0, 2],
                        'past_key_values.2.decoder.value': [0, 2],
                        'past_key_values.2.encoder.key': [0, 2],
                        'past_key_values.2.encoder.value': [0, 2],
                        'past_key_values.3.decoder.key': [0, 2],
                        'past_key_values.3.decoder.value': [0, 2],
                        'past_key_values.3.encoder.key': [0, 2],
                        'past_key_values.3.encoder.value': [0, 2],
                        'past_key_values.4.decoder.key': [0, 2],
                        'past_key_values.4.decoder.value': [0, 2],
                        'past_key_values.4.encoder.key': [0, 2],
                        'past_key_values.4.encoder.value': [0, 2],
                        'past_key_values.5.decoder.key': [0, 2],
                        'past_key_values.5.decoder.value': [0, 2],
                        'past_key_values.5.encoder.key': [0, 2],
                        'past_key_values.5.encoder.value': [0, 2],
                        'past_key_values.6.decoder.key': [0, 2],
                        'past_key_values.6.decoder.value': [0, 2],
                        'past_key_values.6.encoder.key': [0, 2],
                        'past_key_values.6.encoder.value': [0, 2],
                        'past_key_values.7.decoder.key': [0, 2],
                        'past_key_values.7.decoder.value': [0, 2],
                        'past_key_values.7.encoder.key': [0, 2],
                        'past_key_values.7.encoder.value': [0, 2],
                        'past_key_values.8.decoder.key': [0, 2],
                        'past_key_values.8.decoder.value': [0, 2],
                        'past_key_values.8.encoder.key': [0, 2],
                        'past_key_values.8.encoder.value': [0, 2],
                        'past_key_values.9.decoder.key': [0, 2],
                        'past_key_values.9.decoder.value': [0, 2],
                        'past_key_values.9.encoder.key': [0, 2],
                        'past_key_values.9.encoder.value': [0, 2],
                        'past_key_values.10.decoder.key': [0, 2],
                        'past_key_values.10.decoder.value': [0, 2],
                        'past_key_values.10.encoder.key': [0, 2],
                        'past_key_values.10.encoder.value': [0, 2],
                        'past_key_values.11.decoder.key': [0, 2],
                        'past_key_values.11.decoder.value': [0, 2],
                        'past_key_values.11.encoder.key': [0, 2],
                        'past_key_values.11.encoder.value': [0, 2],
                        'past_key_values.12.decoder.key': [0, 2],
                        'past_key_values.12.decoder.value': [0, 2],
                        'past_key_values.12.encoder.key': [0, 2],
                        'past_key_values.12.encoder.value': [0, 2],
                        'past_key_values.13.decoder.key': [0, 2],
                        'past_key_values.13.decoder.value': [0, 2],
                        'past_key_values.13.encoder.key': [0, 2],
                        'past_key_values.13.encoder.value': [0, 2],
                        'past_key_values.14.decoder.key': [0, 2],
                        'past_key_values.14.decoder.value': [0, 2],
                        'past_key_values.14.encoder.key': [0, 2],
                        'past_key_values.14.encoder.value': [0, 2],
                        'past_key_values.15.decoder.key': [0, 2],
                        'past_key_values.15.decoder.value': [0, 2],
                        'past_key_values.15.encoder.key': [0, 2],
                        'past_key_values.15.encoder.value': [0, 2],
                        'present.0.decoder.key': [0, 2],
                        'present.0.decoder.value': [0, 2],
                        'present.0.encoder.key': [0, 2],
                        'present.0.encoder.value': [0, 2],
                        'present.1.decoder.key': [0, 2],
                        'present.1.decoder.value': [0, 2],
                        'present.1.encoder.key': [0, 2],
                        'present.1.encoder.value': [0, 2],
                        'present.2.decoder.key': [0, 2],
                        'present.2.decoder.value': [0, 2],
                        'present.2.encoder.key': [0, 2],
                        'present.2.encoder.value': [0, 2],
                        'present.3.decoder.key': [0, 2],
                        'present.3.decoder.value': [0, 2],
                        'present.3.encoder.key': [0, 2],
                        'present.3.encoder.value': [0, 2],
                        'present.4.decoder.key': [0, 2],
                        'present.4.decoder.value': [0, 2],
                        'present.4.encoder.key': [0, 2],
                        'present.4.encoder.value': [0, 2],
                        'present.5.decoder.key': [0, 2],
                        'present.5.decoder.value': [0, 2],
                        'present.5.encoder.key': [0, 2],
                        'present.5.encoder.value': [0, 2],
                        'present.6.decoder.key': [0, 2],
                        'present.6.decoder.value': [0, 2],
                        'present.6.encoder.key': [0, 2],
                        'present.6.encoder.value': [0, 2],
                        'present.7.decoder.key': [0, 2],
                        'present.7.decoder.value': [0, 2],
                        'present.7.encoder.key': [0, 2],
                        'present.7.encoder.value': [0, 2],
                        'present.8.decoder.key': [0, 2],
                        'present.8.decoder.value': [0, 2],
                        'present.8.encoder.key': [0, 2],
                        'present.8.encoder.value': [0, 2],
                        'present.9.decoder.key': [0, 2],
                        'present.9.decoder.value': [0, 2],
                        'present.9.encoder.key': [0, 2],
                        'present.9.encoder.value': [0, 2],
                        'present.10.decoder.key': [0, 2],
                        'present.10.decoder.value': [0, 2],
                        'present.10.encoder.key': [0, 2],
                        'present.10.encoder.value': [0, 2],
                        'present.11.decoder.key': [0, 2],
                        'present.11.decoder.value': [0, 2],
                        'present.11.encoder.key': [0, 2],
                        'present.11.encoder.value': [0, 2],
                        'present.12.decoder.key': [0, 2],
                        'present.12.decoder.value': [0, 2],
                        'present.12.encoder.key': [0, 2],
                        'present.12.encoder.value': [0, 2],
                        'present.13.decoder.key': [0, 2],
                        'present.13.decoder.value': [0, 2],
                        'present.13.encoder.key': [0, 2],
                        'present.13.encoder.value': [0, 2],
                        'present.14.decoder.key': [0, 2],
                        'present.14.decoder.value': [0, 2],
                        'present.14.encoder.key': [0, 2],
                        'present.14.encoder.value': [0, 2],
                        'present.15.decoder.key': [0, 2],
                        'present.15.decoder.value': [0, 2],
                        'present.15.encoder.key': [0, 2],
                        'present.15.encoder.value': [0, 2],
                    },
                    verbose=False, opset_version=14
                )
                print("<------")
                exit()
ooe1123 commented 6 months ago

encode_imagesエクスポート

〇 LLaVA/llava/model/llava_arch.py

class LlavaMetaForCausalLM(ABC):
    ...
    def prepare_inputs_labels_for_multimodal(
        ...
    ):
        ...
        if type(images) is list or images.ndim == 5:
           ...
        else:
            image_features = self.encode_images(images)

class Exp(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.vision_tower = model.get_vision_tower()
        self.mm_projector = model.mm_projector

    def forward(self, images):
        image_features = self.vision_tower(images)
        image_features = self.mm_projector(image_features)
        return image_features

class LlavaMetaForCausalLM(ABC):
    ...
    def prepare_inputs_labels_for_multimodal(
        ...
    ):
        ...
        if type(images) is list or images.ndim == 5:
           ...
        else:
            net = Exp(self.get_model())
            image_features = net(images)
            print("------>")
            from torch.autograd import Variable
            x = Variable(images)
            torch.onnx.export(
                net, x, 'encode_images.onnx',
                input_names=["images"],
                output_names=["image_features"],
                dynamic_axes={'images': [0], 'image_features': [0]},
                verbose=False, opset_version=14
            )
            print("<------")
            exit()

embed_tokensエクスポート

〇 LLaVA/llava/model/llava_arch.py

class LlavaMetaForCausalLM(ABC):
    ...
    def prepare_inputs_labels_for_multimodal(
        ...
    ):
        ...
        for batch_idx, cur_input_ids in enumerate(input_ids):
            ...
            cur_input_embeds = self.get_model().embed_tokens(
                torch.cat(cur_input_ids_noim)
            )

class LlavaMetaForCausalLM(ABC):
    ...
    def prepare_inputs_labels_for_multimodal(
        ...
    ):
        ...
        for batch_idx, cur_input_ids in enumerate(input_ids):
            ...
            print("------>")
            from torch.autograd import Variable
            x = Variable(torch.cat(cur_input_ids_noim))
            torch.onnx.export(
                self.get_model().embed_tokens, x, 'embed_tokens.onnx',
                input_names=["input_ids"],
                output_names=["embeds"],
                dynamic_axes={'input_ids': [0], 'embeds': [0]},
                verbose=False, opset_version=14
            )
            print("<------")
            exit()