@ooe1123 Once the other models are done, I'd appreciate it if you could take this one.
〇 transformers/models/whisper/modeling_whisper.py
class WhisperSdpaAttention(WhisperAttention):
...
def forward(
self,
...
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
...
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
...
elif is_cross_attention:
...
elif past_key_value is not None:
...
else:
...
...
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
class WhisperDecoder(WhisperPreTrainedModel):
...
def forward(
...
):
...
if self._use_flash_attention_2:
...
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
...
↓
class WhisperSdpaAttention(WhisperAttention):
...
def forward(
self,
...
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
...
        if is_cross_attention:
            # cross-attention k,v: concatenate the cache with fresh projections,
            # then slice back to Whisper's fixed 1500 encoder frames (30 s of audio),
            # so no data-dependent branch is traced into the exported graph
            key_states = torch.cat([past_key_value[0], self._shape(self.k_proj(key_value_states), -1, bsz)], dim=2)
            value_states = torch.cat([past_key_value[1], self._shape(self.v_proj(key_value_states), -1, bsz)], dim=2)
            key_states = key_states[:, :, :1500, :]
            value_states = value_states[:, :, :1500, :]
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
...
if torch.onnx.is_in_onnx_export():
if self.is_causal:
attn_output_1 = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=False
)
attn_output_2 = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=True
)
                ind = torch.gt(tgt_len, 1).type(torch.int64)  # 1 if tgt_len > 1, else 0 (stays a tensor under tracing)
                sel = torch.stack([attn_output_1, attn_output_2])
                attn_output = sel[ind]  # tensor-indexed select keeps the causal/non-causal choice in the graph
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=False,
)
else:
            # original implementation
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)
class WhisperDecoder(WhisperPreTrainedModel):
...
def forward(
...
):
...
if self._use_flash_attention_2:
...
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA.
# attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
# attention_mask, input_shape, inputs_embeds, past_key_values_length
# )
attention_mask = None
else:
...
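For reference, the stack-and-index trick in the torch.onnx.is_in_onnx_export() branch exists because tracing cannot record a Python bool derived from tensor shapes: both SDPA variants are evaluated and one is selected by a tensor index, which keeps the causal/non-causal decision inside the exported graph. A minimal standalone sketch of the same pattern (the shapes are illustrative, not taken from the patch):

import torch
import torch.nn.functional as F

q = torch.randn(1, 20, 4, 64)  # (batch, heads, tgt_len, head_dim)
k = torch.randn(1, 20, 4, 64)
v = torch.randn(1, 20, 4, 64)
tgt_len = torch.tensor(q.shape[2])

# Evaluate both branches unconditionally, then pick one via a tensor index.
out_plain = F.scaled_dot_product_attention(q, k, v, is_causal=False)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
ind = torch.gt(tgt_len, 1).type(torch.int64)  # 1 if tgt_len > 1 else 0
attn_output = torch.stack([out_plain, out_causal])[ind]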
〇 transformers/generation/utils.py
class GenerationMixin:
...
def _greedy_search(
...
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
...
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
↓
class GenerationMixin:
...
def _greedy_search(
...
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
...
if 1:
class Net(nn.Module):
def __init__(self, net):
super(Net, self).__init__()
self.net = net
def forward(
self, decoder_input_ids, encoder_hidden_states,
past_key_values_0_decoder_key, past_key_values_0_decoder_value, past_key_values_0_encoder_key, past_key_values_0_encoder_value, past_key_values_1_decoder_key, past_key_values_1_decoder_value, past_key_values_1_encoder_key, past_key_values_1_encoder_value,
):
model_inputs = {
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": [encoder_hidden_states],
"past_key_values": [
[
past_key_values_0_decoder_key,
past_key_values_0_decoder_value,
past_key_values_0_encoder_key,
past_key_values_0_encoder_value,
],
[
past_key_values_1_decoder_key,
past_key_values_1_decoder_value,
past_key_values_1_encoder_key,
past_key_values_1_encoder_value,
],
],
}
outputs = self.net(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# return outputs # Updated
return (
outputs["logits"],
outputs["past_key_values"][0][0].type(torch.float16),
outputs["past_key_values"][0][1].type(torch.float16),
outputs["past_key_values"][0][2].type(torch.float16),
outputs["past_key_values"][0][3].type(torch.float16),
outputs["past_key_values"][1][0].type(torch.float16),
outputs["past_key_values"][1][1].type(torch.float16),
outputs["past_key_values"][1][2].type(torch.float16),
outputs["past_key_values"][1][3].type(torch.float16),
)
model = Net(self)
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # Add: materialize length-0 caches on the first step so the ONNX
            # wrapper always receives concrete past_key_values tensors
            if model_inputs["past_key_values"] is None:
b = model_inputs["encoder_outputs"][0].size(0)
d = model_inputs["encoder_outputs"][0].device
model_inputs["past_key_values"] = [
[
torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
]
] * 2
            if 1 and 0 < model_inputs["past_key_values"][0][0].size(2):  # export once the decoder cache is non-empty
print("------>")
from torch.autograd import Variable
xx = (
Variable(model_inputs["decoder_input_ids"]),
Variable(model_inputs["encoder_outputs"].last_hidden_state),
Variable(model_inputs["past_key_values"][0][0]),
Variable(model_inputs["past_key_values"][0][1]),
Variable(model_inputs["past_key_values"][0][2]),
Variable(model_inputs["past_key_values"][0][3]),
Variable(model_inputs["past_key_values"][1][0]),
Variable(model_inputs["past_key_values"][1][1]),
Variable(model_inputs["past_key_values"][1][2]),
Variable(model_inputs["past_key_values"][1][3]),
)
torch.onnx.export(
model, xx, 'decoder_model.onnx',
input_names=[
'input_ids', 'encoder_hidden_states', 'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value', 'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value', 'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value', 'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value',
],
output_names=[
'logits',
'present.0.decoder.key', 'present.0.decoder.value', 'present.0.encoder.key', 'present.0.encoder.value', 'present.1.decoder.key', 'present.1.decoder.value', 'present.1.encoder.key', 'present.1.encoder.value',
],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'},
'encoder_hidden_states': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
'logits': {0: 'batch_size', 1: 'decoder_sequence_length'},
'past_key_values.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.0.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.0.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.1.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'past_key_values.1.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
'present.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
'present.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
'present.0.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
'present.0.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
'present.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
'present.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
'present.1.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
'present.1.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
},
verbose=False, opset_version=14
)
print("<------")
exit(0)
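If the export succeeds, the flattened decoder can be exercised with onnxruntime roughly as below. This is a sketch under assumptions: kotoba-whisper-v1.0's decoder has 2 layers, 20 heads, head_dim 64 (d_model 1280), the caches are fp16 as created above, the graph accepts length-0 past tensors on the first step, and the start token is hypothetical.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder_model.onnx")
b = 1
feeds = {
    "input_ids": np.array([[50258]], dtype=np.int64),  # hypothetical start-of-transcript token
    "encoder_hidden_states": np.zeros((b, 1500, 1280), dtype=np.float32),
}
# Length-0 caches on the first step, mirroring the zero tensors created above.
for l in range(2):
    for name in ("decoder.key", "decoder.value", "encoder.key", "encoder.value"):
        feeds[f"past_key_values.{l}.{name}"] = np.zeros((b, 20, 0, 64), dtype=np.float16)

logits, *presents = sess.run(None, feeds)  # presents are fed back on the next step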
When exporting with opset=17, the following error occurs, so the workaround below is needed:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from kotoba-whisper-v1.0_decoder.onnx failed:Type Error: Type parameter (T) of Optype (LayerNormalization) bound to different types (tensor(float) and tensor(float16) in node (/net/model/decoder/layers.0/encoder_attn_layer_norm/LayerNormalization).
〇 transformers/models/whisper/modeling_whisper.py
class WhisperDecoderLayer(nn.Module):
...
def forward(
...
) -> torch.Tensor:
...
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
↓
class WhisperDecoderLayer(nn.Module):
...
def forward(
...
) -> torch.Tensor:
...
if encoder_hidden_states is not None:
residual = hidden_states
            hidden_states = hidden_states.type(torch.float16)  # keep LayerNormalization's type parameter consistent (see error above)
hidden_states = self.encoder_attn_layer_norm(hidden_states)
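The cast works around the error because ONNX's LayerNormalization binds a single type parameter T per node, so float32 activations meeting float16 weights in one node fail at load time. A toy eager-mode repro of the mismatch (assuming fp16 LayerNorm weights, as in the exported decoder):

import torch

ln = torch.nn.LayerNorm(4).half()  # fp16 weights
x = torch.randn(1, 4)              # fp32 activations
# ln(x) raises a dtype-mismatch error; casting first, as in the patch, avoids it
out = ln(x.type(torch.float16))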
A Whisper model specialized for Japanese: https://huggingface.co/kotoba-tech/kotoba-whisper-v1.0