Closed kyakuno closed 4 months ago
llava-v1.5-7b.onnxエクスポート
〇 transformers/models/llama/modeling_llama.py
class LlamaModel(LlamaPreTrainedModel):
...
def forward(
...
) -> Union[Tuple, BaseModelOutputWithPast]:
...
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
↓
class LlamaModel(LlamaPreTrainedModel):
...
def forward(
...
) -> Union[Tuple, BaseModelOutputWithPast]:
...
inputs_embeds = torch.cat([inputs_embeds, self.embed_tokens(input_ids)], dim=1)
batch_size, seq_length = inputs_embeds.shape[:2]
class LlamaSdpaAttention(LlamaAttention):
...
# Adapted from LlamaAttention.forward
def forward(
...
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
...
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
↓
class LlamaSdpaAttention(LlamaAttention):
...
# Adapted from LlamaAttention.forward
def forward(
...
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
...
Q = query_states
K = key_states
V = value_states
L, S = Q.size(-2), K.size(-2)
attn_bias = torch.zeros(L, S, dtype=Q.dtype)
if torch.onnx.is_in_onnx_export():
def tril(L, S):
arange = torch.arange(S)
mask = arange.expand(S, S)
arange = arange.unsqueeze(-1)
mask = torch.le(mask, arange)[:L]
return mask
is_causal = torch.gt(q_len, 1).type(torch.int64)
sel = torch.stack([
torch.ones(L, S, dtype=torch.bool),
tril(L, S)
])
mask = sel[is_causal]
else:
is_causal=q_len > 1
if is_causal:
mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
else:
mask = torch.ones(L, S, dtype=torch.bool)
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(Q.device)
if torch.onnx.is_in_onnx_export():
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / torch.sqrt(Q.size(-1))) + attn_bias, dim=-1)
else:
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_bias, dim=-1)
attn_output = attn_weight @ V
〇 transformers/generation/utils.py
class GenerationMixin:
...
def greedy_search(
...
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
...
while True:
...
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
↓
class GenerationMixin:
...
def greedy_search(
...
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
...
while True:
...
b = model_inputs["position_ids"].size(0)
d = model_inputs["position_ids"].device
if "input_ids" not in model_inputs:
model_inputs["input_ids"] = torch.zeros(b, 0, dtype=torch.int64).to(d)
if "inputs_embeds" not in model_inputs:
model_inputs["inputs_embeds"] = torch.zeros(b, 0, 4096, dtype=torch.float16).to(d)
if model_inputs["past_key_values"] is None:
model_inputs["past_key_values"] = [
[
torch.zeros(b, 32, 0, 128, dtype=torch.float16).to(d)
] * 2
] * 32
if 0 < model_inputs["past_key_values"][0][0].size(2):
class Net(nn.Module):
def __init__(self, net):
super(Net, self).__init__()
self.net = net
def forward(
self,
input_ids, inputs_embeds,
position_ids, attention_mask,
past_key_values_0_key, past_key_values_0_value,
past_key_values_1_key, past_key_values_1_value,
past_key_values_2_key, past_key_values_2_value,
past_key_values_3_key, past_key_values_3_value,
past_key_values_4_key, past_key_values_4_value,
past_key_values_5_key, past_key_values_5_value,
past_key_values_6_key, past_key_values_6_value,
past_key_values_7_key, past_key_values_7_value,
past_key_values_8_key, past_key_values_8_value,
past_key_values_9_key, past_key_values_9_value,
past_key_values_10_key, past_key_values_10_value,
past_key_values_11_key, past_key_values_11_value,
past_key_values_12_key, past_key_values_12_value,
past_key_values_13_key, past_key_values_13_value,
past_key_values_14_key, past_key_values_14_value,
past_key_values_15_key, past_key_values_15_value,
past_key_values_16_key, past_key_values_16_value,
past_key_values_17_key, past_key_values_17_value,
past_key_values_18_key, past_key_values_18_value,
past_key_values_19_key, past_key_values_19_value,
past_key_values_20_key, past_key_values_20_value,
past_key_values_21_key, past_key_values_21_value,
past_key_values_22_key, past_key_values_22_value,
past_key_values_23_key, past_key_values_23_value,
past_key_values_24_key, past_key_values_24_value,
past_key_values_25_key, past_key_values_25_value,
past_key_values_26_key, past_key_values_26_value,
past_key_values_27_key, past_key_values_27_value,
past_key_values_28_key, past_key_values_28_value,
past_key_values_29_key, past_key_values_29_value,
past_key_values_30_key, past_key_values_30_value,
past_key_values_31_key, past_key_values_31_value,
):
model_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"attention_mask": attention_mask,
"past_key_values": [
[ past_key_values_0_key, past_key_values_0_value ],
[ past_key_values_1_key, past_key_values_1_value ],
[ past_key_values_2_key, past_key_values_2_value ],
[ past_key_values_3_key, past_key_values_3_value ],
[ past_key_values_4_key, past_key_values_4_value ],
[ past_key_values_5_key, past_key_values_5_value ],
[ past_key_values_6_key, past_key_values_6_value ],
[ past_key_values_7_key, past_key_values_7_value ],
[ past_key_values_8_key, past_key_values_8_value ],
[ past_key_values_9_key, past_key_values_9_value ],
[ past_key_values_10_key, past_key_values_10_value ],
[ past_key_values_11_key, past_key_values_11_value ],
[ past_key_values_12_key, past_key_values_12_value ],
[ past_key_values_13_key, past_key_values_13_value ],
[ past_key_values_14_key, past_key_values_14_value ],
[ past_key_values_15_key, past_key_values_15_value ],
[ past_key_values_16_key, past_key_values_16_value ],
[ past_key_values_17_key, past_key_values_17_value ],
[ past_key_values_18_key, past_key_values_18_value ],
[ past_key_values_19_key, past_key_values_19_value ],
[ past_key_values_20_key, past_key_values_20_value ],
[ past_key_values_21_key, past_key_values_21_value ],
[ past_key_values_22_key, past_key_values_22_value ],
[ past_key_values_23_key, past_key_values_23_value ],
[ past_key_values_24_key, past_key_values_24_value ],
[ past_key_values_25_key, past_key_values_25_value ],
[ past_key_values_26_key, past_key_values_26_value ],
[ past_key_values_27_key, past_key_values_27_value ],
[ past_key_values_28_key, past_key_values_28_value ],
[ past_key_values_29_key, past_key_values_29_value ],
[ past_key_values_30_key, past_key_values_30_value ],
[ past_key_values_31_key, past_key_values_31_value ],
],
"use_cache": True,
}
outputs = self.net(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return (
outputs["logits"],
outputs["past_key_values"][0][0],
outputs["past_key_values"][0][1],
outputs["past_key_values"][1][0],
outputs["past_key_values"][1][1],
outputs["past_key_values"][2][0],
outputs["past_key_values"][2][1],
outputs["past_key_values"][3][0],
outputs["past_key_values"][3][1],
outputs["past_key_values"][4][0],
outputs["past_key_values"][4][1],
outputs["past_key_values"][5][0],
outputs["past_key_values"][5][1],
outputs["past_key_values"][6][0],
outputs["past_key_values"][6][1],
outputs["past_key_values"][7][0],
outputs["past_key_values"][7][1],
outputs["past_key_values"][8][0],
outputs["past_key_values"][8][1],
outputs["past_key_values"][9][0],
outputs["past_key_values"][9][1],
outputs["past_key_values"][10][0],
outputs["past_key_values"][10][1],
outputs["past_key_values"][11][0],
outputs["past_key_values"][11][1],
outputs["past_key_values"][12][0],
outputs["past_key_values"][12][1],
outputs["past_key_values"][13][0],
outputs["past_key_values"][13][1],
outputs["past_key_values"][14][0],
outputs["past_key_values"][14][1],
outputs["past_key_values"][15][0],
outputs["past_key_values"][15][1],
outputs["past_key_values"][16][0],
outputs["past_key_values"][16][1],
outputs["past_key_values"][17][0],
outputs["past_key_values"][17][1],
outputs["past_key_values"][18][0],
outputs["past_key_values"][18][1],
outputs["past_key_values"][19][0],
outputs["past_key_values"][19][1],
outputs["past_key_values"][20][0],
outputs["past_key_values"][20][1],
outputs["past_key_values"][21][0],
outputs["past_key_values"][21][1],
outputs["past_key_values"][22][0],
outputs["past_key_values"][22][1],
outputs["past_key_values"][23][0],
outputs["past_key_values"][23][1],
outputs["past_key_values"][24][0],
outputs["past_key_values"][24][1],
outputs["past_key_values"][25][0],
outputs["past_key_values"][25][1],
outputs["past_key_values"][26][0],
outputs["past_key_values"][26][1],
outputs["past_key_values"][27][0],
outputs["past_key_values"][27][1],
outputs["past_key_values"][28][0],
outputs["past_key_values"][28][1],
outputs["past_key_values"][29][0],
outputs["past_key_values"][29][1],
outputs["past_key_values"][30][0],
outputs["past_key_values"][30][1],
outputs["past_key_values"][31][0],
outputs["past_key_values"][31][1],
)
model = Net(self)
from torch.autograd import Variable
xx = (
Variable(model_inputs["input_ids"]),
Variable(model_inputs["inputs_embeds"]),
Variable(model_inputs["position_ids"]),
Variable(model_inputs["attention_mask"]),
Variable(model_inputs["past_key_values"][0][0]),
Variable(model_inputs["past_key_values"][0][1]),
Variable(model_inputs["past_key_values"][1][0]),
Variable(model_inputs["past_key_values"][1][1]),
Variable(model_inputs["past_key_values"][2][0]),
Variable(model_inputs["past_key_values"][2][1]),
Variable(model_inputs["past_key_values"][3][0]),
Variable(model_inputs["past_key_values"][3][1]),
Variable(model_inputs["past_key_values"][4][0]),
Variable(model_inputs["past_key_values"][4][1]),
Variable(model_inputs["past_key_values"][5][0]),
Variable(model_inputs["past_key_values"][5][1]),
Variable(model_inputs["past_key_values"][6][0]),
Variable(model_inputs["past_key_values"][6][1]),
Variable(model_inputs["past_key_values"][7][0]),
Variable(model_inputs["past_key_values"][7][1]),
Variable(model_inputs["past_key_values"][8][0]),
Variable(model_inputs["past_key_values"][8][1]),
Variable(model_inputs["past_key_values"][9][0]),
Variable(model_inputs["past_key_values"][9][1]),
Variable(model_inputs["past_key_values"][10][0]),
Variable(model_inputs["past_key_values"][10][1]),
Variable(model_inputs["past_key_values"][11][0]),
Variable(model_inputs["past_key_values"][11][1]),
Variable(model_inputs["past_key_values"][12][0]),
Variable(model_inputs["past_key_values"][12][1]),
Variable(model_inputs["past_key_values"][13][0]),
Variable(model_inputs["past_key_values"][13][1]),
Variable(model_inputs["past_key_values"][14][0]),
Variable(model_inputs["past_key_values"][14][1]),
Variable(model_inputs["past_key_values"][15][0]),
Variable(model_inputs["past_key_values"][15][1]),
Variable(model_inputs["past_key_values"][16][0]),
Variable(model_inputs["past_key_values"][16][1]),
Variable(model_inputs["past_key_values"][17][0]),
Variable(model_inputs["past_key_values"][17][1]),
Variable(model_inputs["past_key_values"][18][0]),
Variable(model_inputs["past_key_values"][18][1]),
Variable(model_inputs["past_key_values"][19][0]),
Variable(model_inputs["past_key_values"][19][1]),
Variable(model_inputs["past_key_values"][20][0]),
Variable(model_inputs["past_key_values"][20][1]),
Variable(model_inputs["past_key_values"][21][0]),
Variable(model_inputs["past_key_values"][21][1]),
Variable(model_inputs["past_key_values"][22][0]),
Variable(model_inputs["past_key_values"][22][1]),
Variable(model_inputs["past_key_values"][23][0]),
Variable(model_inputs["past_key_values"][23][1]),
Variable(model_inputs["past_key_values"][24][0]),
Variable(model_inputs["past_key_values"][24][1]),
Variable(model_inputs["past_key_values"][25][0]),
Variable(model_inputs["past_key_values"][25][1]),
Variable(model_inputs["past_key_values"][26][0]),
Variable(model_inputs["past_key_values"][26][1]),
Variable(model_inputs["past_key_values"][27][0]),
Variable(model_inputs["past_key_values"][27][1]),
Variable(model_inputs["past_key_values"][28][0]),
Variable(model_inputs["past_key_values"][28][1]),
Variable(model_inputs["past_key_values"][29][0]),
Variable(model_inputs["past_key_values"][29][1]),
Variable(model_inputs["past_key_values"][30][0]),
Variable(model_inputs["past_key_values"][30][1]),
Variable(model_inputs["past_key_values"][31][0]),
Variable(model_inputs["past_key_values"][31][1]),
)
print("------>")
torch.onnx.export(
model, xx, 'onnx/llava-v1.5-7b.onnx',
input_names=[
'input_ids', 'inputs_embeds',
'position_ids', 'attention_mask',
'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value',
'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value',
'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value',
'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value',
'past_key_values.2.decoder.key', 'past_key_values.2.decoder.value',
'past_key_values.2.encoder.key', 'past_key_values.2.encoder.value',
'past_key_values.3.decoder.key', 'past_key_values.3.decoder.value',
'past_key_values.3.encoder.key', 'past_key_values.3.encoder.value',
'past_key_values.4.decoder.key', 'past_key_values.4.decoder.value',
'past_key_values.4.encoder.key', 'past_key_values.4.encoder.value',
'past_key_values.5.decoder.key', 'past_key_values.5.decoder.value',
'past_key_values.5.encoder.key', 'past_key_values.5.encoder.value',
'past_key_values.6.decoder.key', 'past_key_values.6.decoder.value',
'past_key_values.6.encoder.key', 'past_key_values.6.encoder.value',
'past_key_values.7.decoder.key', 'past_key_values.7.decoder.value',
'past_key_values.7.encoder.key', 'past_key_values.7.encoder.value',
'past_key_values.8.decoder.key', 'past_key_values.8.decoder.value',
'past_key_values.8.encoder.key', 'past_key_values.8.encoder.value',
'past_key_values.9.decoder.key', 'past_key_values.9.decoder.value',
'past_key_values.9.encoder.key', 'past_key_values.9.encoder.value',
'past_key_values.10.decoder.key', 'past_key_values.10.decoder.value',
'past_key_values.10.encoder.key', 'past_key_values.10.encoder.value',
'past_key_values.11.decoder.key', 'past_key_values.11.decoder.value',
'past_key_values.11.encoder.key', 'past_key_values.11.encoder.value',
'past_key_values.12.decoder.key', 'past_key_values.12.decoder.value',
'past_key_values.12.encoder.key', 'past_key_values.12.encoder.value',
'past_key_values.13.decoder.key', 'past_key_values.13.decoder.value',
'past_key_values.13.encoder.key', 'past_key_values.13.encoder.value',
'past_key_values.14.decoder.key', 'past_key_values.14.decoder.value',
'past_key_values.14.encoder.key', 'past_key_values.14.encoder.value',
'past_key_values.15.decoder.key', 'past_key_values.15.decoder.value',
'past_key_values.15.encoder.key', 'past_key_values.15.encoder.value',
],
output_names=[
'logits',
'present.0.decoder.key', 'present.0.decoder.value',
'present.0.encoder.key', 'present.0.encoder.value',
'present.1.decoder.key', 'present.1.decoder.value',
'present.1.encoder.key', 'present.1.encoder.value',
'present.2.decoder.key', 'present.2.decoder.value',
'present.2.encoder.key', 'present.2.encoder.value',
'present.3.decoder.key', 'present.3.decoder.value',
'present.3.encoder.key', 'present.3.encoder.value',
'present.4.decoder.key', 'present.4.decoder.value',
'present.4.encoder.key', 'present.4.encoder.value',
'present.5.decoder.key', 'present.5.decoder.value',
'present.5.encoder.key', 'present.5.encoder.value',
'present.6.decoder.key', 'present.6.decoder.value',
'present.6.encoder.key', 'present.6.encoder.value',
'present.7.decoder.key', 'present.7.decoder.value',
'present.7.encoder.key', 'present.7.encoder.value',
'present.8.decoder.key', 'present.8.decoder.value',
'present.8.encoder.key', 'present.8.encoder.value',
'present.9.decoder.key', 'present.9.decoder.value',
'present.9.encoder.key', 'present.9.encoder.value',
'present.10.decoder.key', 'present.10.decoder.value',
'present.10.encoder.key', 'present.10.encoder.value',
'present.11.decoder.key', 'present.11.decoder.value',
'present.11.encoder.key', 'present.11.encoder.value',
'present.12.decoder.key', 'present.12.decoder.value',
'present.12.encoder.key', 'present.12.encoder.value',
'present.13.decoder.key', 'present.13.decoder.value',
'present.13.encoder.key', 'present.13.encoder.value',
'present.14.decoder.key', 'present.14.decoder.value',
'present.14.encoder.key', 'present.14.encoder.value',
'present.15.decoder.key', 'present.15.decoder.value',
'present.15.encoder.key', 'present.15.encoder.value',
],
dynamic_axes={
'input_ids': [0, 1],
'inputs_embeds': [0, 1],
'logits': [0, 1],
'position_ids': [0, 1],
'attention_mask': [0, 1],
'past_key_values.0.decoder.key': [0, 2],
'past_key_values.0.decoder.value': [0, 2],
'past_key_values.0.encoder.key': [0, 2],
'past_key_values.0.encoder.value': [0, 2],
'past_key_values.1.decoder.key': [0, 2],
'past_key_values.1.decoder.value': [0, 2],
'past_key_values.1.encoder.key': [0, 2],
'past_key_values.1.encoder.value': [0, 2],
'past_key_values.2.decoder.key': [0, 2],
'past_key_values.2.decoder.value': [0, 2],
'past_key_values.2.encoder.key': [0, 2],
'past_key_values.2.encoder.value': [0, 2],
'past_key_values.3.decoder.key': [0, 2],
'past_key_values.3.decoder.value': [0, 2],
'past_key_values.3.encoder.key': [0, 2],
'past_key_values.3.encoder.value': [0, 2],
'past_key_values.4.decoder.key': [0, 2],
'past_key_values.4.decoder.value': [0, 2],
'past_key_values.4.encoder.key': [0, 2],
'past_key_values.4.encoder.value': [0, 2],
'past_key_values.5.decoder.key': [0, 2],
'past_key_values.5.decoder.value': [0, 2],
'past_key_values.5.encoder.key': [0, 2],
'past_key_values.5.encoder.value': [0, 2],
'past_key_values.6.decoder.key': [0, 2],
'past_key_values.6.decoder.value': [0, 2],
'past_key_values.6.encoder.key': [0, 2],
'past_key_values.6.encoder.value': [0, 2],
'past_key_values.7.decoder.key': [0, 2],
'past_key_values.7.decoder.value': [0, 2],
'past_key_values.7.encoder.key': [0, 2],
'past_key_values.7.encoder.value': [0, 2],
'past_key_values.8.decoder.key': [0, 2],
'past_key_values.8.decoder.value': [0, 2],
'past_key_values.8.encoder.key': [0, 2],
'past_key_values.8.encoder.value': [0, 2],
'past_key_values.9.decoder.key': [0, 2],
'past_key_values.9.decoder.value': [0, 2],
'past_key_values.9.encoder.key': [0, 2],
'past_key_values.9.encoder.value': [0, 2],
'past_key_values.10.decoder.key': [0, 2],
'past_key_values.10.decoder.value': [0, 2],
'past_key_values.10.encoder.key': [0, 2],
'past_key_values.10.encoder.value': [0, 2],
'past_key_values.11.decoder.key': [0, 2],
'past_key_values.11.decoder.value': [0, 2],
'past_key_values.11.encoder.key': [0, 2],
'past_key_values.11.encoder.value': [0, 2],
'past_key_values.12.decoder.key': [0, 2],
'past_key_values.12.decoder.value': [0, 2],
'past_key_values.12.encoder.key': [0, 2],
'past_key_values.12.encoder.value': [0, 2],
'past_key_values.13.decoder.key': [0, 2],
'past_key_values.13.decoder.value': [0, 2],
'past_key_values.13.encoder.key': [0, 2],
'past_key_values.13.encoder.value': [0, 2],
'past_key_values.14.decoder.key': [0, 2],
'past_key_values.14.decoder.value': [0, 2],
'past_key_values.14.encoder.key': [0, 2],
'past_key_values.14.encoder.value': [0, 2],
'past_key_values.15.decoder.key': [0, 2],
'past_key_values.15.decoder.value': [0, 2],
'past_key_values.15.encoder.key': [0, 2],
'past_key_values.15.encoder.value': [0, 2],
'present.0.decoder.key': [0, 2],
'present.0.decoder.value': [0, 2],
'present.0.encoder.key': [0, 2],
'present.0.encoder.value': [0, 2],
'present.1.decoder.key': [0, 2],
'present.1.decoder.value': [0, 2],
'present.1.encoder.key': [0, 2],
'present.1.encoder.value': [0, 2],
'present.2.decoder.key': [0, 2],
'present.2.decoder.value': [0, 2],
'present.2.encoder.key': [0, 2],
'present.2.encoder.value': [0, 2],
'present.3.decoder.key': [0, 2],
'present.3.decoder.value': [0, 2],
'present.3.encoder.key': [0, 2],
'present.3.encoder.value': [0, 2],
'present.4.decoder.key': [0, 2],
'present.4.decoder.value': [0, 2],
'present.4.encoder.key': [0, 2],
'present.4.encoder.value': [0, 2],
'present.5.decoder.key': [0, 2],
'present.5.decoder.value': [0, 2],
'present.5.encoder.key': [0, 2],
'present.5.encoder.value': [0, 2],
'present.6.decoder.key': [0, 2],
'present.6.decoder.value': [0, 2],
'present.6.encoder.key': [0, 2],
'present.6.encoder.value': [0, 2],
'present.7.decoder.key': [0, 2],
'present.7.decoder.value': [0, 2],
'present.7.encoder.key': [0, 2],
'present.7.encoder.value': [0, 2],
'present.8.decoder.key': [0, 2],
'present.8.decoder.value': [0, 2],
'present.8.encoder.key': [0, 2],
'present.8.encoder.value': [0, 2],
'present.9.decoder.key': [0, 2],
'present.9.decoder.value': [0, 2],
'present.9.encoder.key': [0, 2],
'present.9.encoder.value': [0, 2],
'present.10.decoder.key': [0, 2],
'present.10.decoder.value': [0, 2],
'present.10.encoder.key': [0, 2],
'present.10.encoder.value': [0, 2],
'present.11.decoder.key': [0, 2],
'present.11.decoder.value': [0, 2],
'present.11.encoder.key': [0, 2],
'present.11.encoder.value': [0, 2],
'present.12.decoder.key': [0, 2],
'present.12.decoder.value': [0, 2],
'present.12.encoder.key': [0, 2],
'present.12.encoder.value': [0, 2],
'present.13.decoder.key': [0, 2],
'present.13.decoder.value': [0, 2],
'present.13.encoder.key': [0, 2],
'present.13.encoder.value': [0, 2],
'present.14.decoder.key': [0, 2],
'present.14.decoder.value': [0, 2],
'present.14.encoder.key': [0, 2],
'present.14.encoder.value': [0, 2],
'present.15.decoder.key': [0, 2],
'present.15.decoder.value': [0, 2],
'present.15.encoder.key': [0, 2],
'present.15.encoder.value': [0, 2],
},
verbose=False, opset_version=14
)
print("<------")
exit()
encode_imagesエクスポート
〇 LLaVA/llava/model/llava_arch.py
class LlavaMetaForCausalLM(ABC):
...
def prepare_inputs_labels_for_multimodal(
...
):
...
if type(images) is list or images.ndim == 5:
...
else:
image_features = self.encode_images(images)
↓
class Exp(nn.Module):
def __init__(self, model):
super().__init__()
self.vision_tower = model.get_vision_tower()
self.mm_projector = model.mm_projector
def forward(self, images):
image_features = self.vision_tower(images)
image_features = self.mm_projector(image_features)
return image_features
class LlavaMetaForCausalLM(ABC):
...
def prepare_inputs_labels_for_multimodal(
...
):
...
if type(images) is list or images.ndim == 5:
...
else:
net = Exp(self.get_model())
image_features = net(images)
print("------>")
from torch.autograd import Variable
x = Variable(images)
torch.onnx.export(
net, x, 'encode_images.onnx',
input_names=["images"],
output_names=["image_features"],
dynamic_axes={'images': [0], 'image_features': [0]},
verbose=False, opset_version=14
)
print("<------")
exit()
embed_tokensエクスポート
〇 LLaVA/llava/model/llava_arch.py
class LlavaMetaForCausalLM(ABC):
...
def prepare_inputs_labels_for_multimodal(
...
):
...
for batch_idx, cur_input_ids in enumerate(input_ids):
...
cur_input_embeds = self.get_model().embed_tokens(
torch.cat(cur_input_ids_noim)
)
↓
class LlavaMetaForCausalLM(ABC):
...
def prepare_inputs_labels_for_multimodal(
...
):
...
for batch_idx, cur_input_ids in enumerate(input_ids):
...
print("------>")
from torch.autograd import Variable
x = Variable(torch.cat(cur_input_ids_noim))
torch.onnx.export(
self.get_model().embed_tokens, x, 'embed_tokens.onnx',
input_names=["input_ids"],
output_names=["embeds"],
dynamic_axes={'input_ids': [0], 'embeds': [0]},
verbose=False, opset_version=14
)
print("<------")
exit()
https://github.com/haotian-liu/LLaVA apache