Nullkooland opened this issue 3 months ago

I checked the output of `model_exported.graph.print_tabular()`:
```
opcode       name                        target                      args    kwargs
-----------  --------------------------  --------------------------  ------  --------
placeholder  p_h_0_ln_1_weight           p_h_0_ln_1_weight           ()      {}
placeholder  p_h_0_ln_1_bias             p_h_0_ln_1_bias             ()      {}
placeholder  p_h_0_attn_q_proj_weight    p_h_0_attn_q_proj_weight    ()      {}
placeholder  p_h_0_attn_k_proj_weight    p_h_0_attn_k_proj_weight    ()      {}
placeholder  p_h_0_attn_v_proj_weight    p_h_0_attn_v_proj_weight    ()      {}
placeholder  p_h_0_attn_out_proj_weight  p_h_0_attn_out_proj_weight  ()      {}
placeholder  p_h_0_mlp_fc_in_weight      p_h_0_mlp_fc_in_weight      ()      {}
placeholder  p_h_0_mlp_fc_in_bias        p_h_0_mlp_fc_in_bias        ()      {}
placeholder  p_h_0_mlp_fc_out_weight     p_h_0_mlp_fc_out_weight     ()      {}
placeholder  p_h_0_mlp_fc_out_bias       p_h_0_mlp_fc_out_bias       ()      {}
placeholder  p_ln_f_weight               p_ln_f_weight               ()      {}
placeholder  p_ln_f_bias                 p_ln_f_bias                 ()      {}
placeholder  c_h_0_attn_embed_positions  c_h_0_attn_embed_positions  ()      {}
placeholder  b_h_0_attn_bias             b_h_0_attn_bias             ()      {}
placeholder  c_h_0_attn_scale_attn       c_h_0_attn_scale_attn       ()      {}
placeholder  input_ids                   input_ids                   ()      {}
placeholder  past_key_values             past_key_values             ()      {}
placeholder  c_h_0_attn_lifted_tensor_0  c_h_0_attn_lifted_tensor_0  ()      {}
placeholder  attention_mask              attention_mask              ()      {}
placeholder  token_type_ids              token_type_ids              ()      {}
placeholder  position_ids                position_ids                ()      {}
placeholder  head_mask                   head_mask                   ()      {}
placeholder  inputs_embeds               inputs_embeds               ()      {}
placeholder  use_cache                   use_cache                   ()      {}
placeholder  output_attentions           output_attentions           ()      {}
placeholder  output_hidden_states        output_hidden_states        ()      {}
placeholder  return_dict                 return_dict                 ()      {}
...
```
and inspected the `_flat_input_args` passed to `XLAExportInterpreter`: it looks like the lifted constant tensor arg `c_h_0_attn_lifted_tensor_0` ends up at the wrong position in the graph input signature, mismatched with its expected position in `_flat_input_args`.
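To make the mismatch concrete, here is a minimal sketch of the kind of alignment check I used. The helper `check_input_alignment` and the shortened name lists are my own illustration (not torch_xla API); the names mirror the placeholder printout above.

```python
def check_input_alignment(placeholder_names, flat_arg_names):
    """Return (position, graph_name, flat_arg_name) tuples where the
    exported graph's placeholder order disagrees with the flattened
    input args passed to the interpreter."""
    mismatches = []
    for pos, (ph, arg) in enumerate(zip(placeholder_names, flat_arg_names)):
        if ph != arg:
            mismatches.append((pos, ph, arg))
    return mismatches

# Toy version of the reported bug: the lifted constant sits at a
# different position in the graph signature than in the flat args.
graph_sig = ["p_ln_f_weight", "c_h_0_attn_lifted_tensor_0", "input_ids"]
flat_args = ["p_ln_f_weight", "input_ids", "c_h_0_attn_lifted_tensor_0"]
print(check_input_alignment(graph_sig, flat_args))
# → [(1, 'c_h_0_attn_lifted_tensor_0', 'input_ids'),
#    (2, 'input_ids', 'c_h_0_attn_lifted_tensor_0')]
```

With the real exported program, every mismatch after the lifted constant's position cascades, so the interpreter receives the wrong tensor for each subsequent input.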
## 🐛 Bug

The issue looks related to how lifted constants are handled during `torch.export`. I found a commit https://github.com/pytorch/xla/commit/d8d7e58b78664aff2713e5f25adb3d61c42d44e7 that might be related, but this code does not exist in the v2.4.0 release of torch_xla.

## To Reproduce

## Environment