pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[StableHLO] Failed to export GPT-J model from HF Transformers #7845

Open Nullkooland opened 3 months ago

Nullkooland commented 3 months ago

🐛 Bug

The issue looks related to constants lifted during torch.export. I found a commit that might be related, https://github.com/pytorch/xla/commit/d8d7e58b78664aff2713e5f25adb3d61c42d44e7, but that code does not exist in the v2.4.0 release of torch_xla. The export fails with:

  File ".../export_gptj.py", line 64, in <module>
    stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(
  File ".../torch_xla/torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
  File ".../torch_xla/torch_xla/stablehlo.py", line 362, in _exported_program_to_stablehlo_bundle
    res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
  File ".../pytorch/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File ".../torch_xla/torch_xla/stablehlo.py", line 274, in run_node
    self._mark_dynamic(res, dynamic_dims)
  File ".../torch_xla/torch_xla/stablehlo.py", line 234, in _mark_dynamic
    tid = torch_xla._XLAC._xla_get_tensor_id(tensor)
TypeError: _xla_get_tensor_id(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor) -> int

Invoked with: None

While executing %c_h_0_attn_lifted_tensor_0 : [num_users=1] = placeholder[target=c_h_0_attn_lifted_tensor_0]
Original traceback:
None
I0000 00:00:1723540341.666568 1272789 cpu_client.cc:470] TfrtCpuClient destroyed.
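
A minimal sketch for checking which placeholders torch.export lifted as tensor constants (assuming the model_exported ExportedProgram from the repro below; this uses the torch.export graph-signature API and is not code from torch_xla):

from torch.export.graph_signature import InputKind

# model_exported is the ExportedProgram produced by torch.export.export()
# in the repro script below.
for spec in model_exported.graph_signature.input_specs:
    if spec.kind == InputKind.CONSTANT_TENSOR:
        # spec.arg.name is the placeholder name in the FX graph;
        # spec.target is the key into model_exported.constants.
        print(spec.arg.name, "->", spec.target)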

To Reproduce

import torch
import torch.export
import torch_xla

from transformers.models.gptj.modeling_gptj import GPTJConfig, GPTJModel

if __name__ == "__main__":
    batch_size = 4
    token_size = 100

    # Minimal single-layer GPT-J config to keep the repro small.
    model_config = GPTJConfig(
        num_hidden_layers=1,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_size=4096,
        intermediate_size=16384,
        max_position_embeddings=4096,
        vocab_size=50400,
        torch_dtype=torch.float16,
        attn_implementation="eager",
        use_cache=False,
        return_dict=False,
    )

    model = GPTJModel(model_config).to(model_config.torch_dtype).eval()

    # Sample inputs.
    hidden_states = torch.randn(
        size=(batch_size, token_size, model_config.hidden_size),
        dtype=model_config.torch_dtype
    )
    attention_mask = torch.ones(
        size=(batch_size, token_size),
        dtype=model_config.torch_dtype,
    )
    position_ids = torch.arange(
        start=0,
        end=token_size,
        dtype=torch.int64,
    ).expand(size=(batch_size, token_size))

    sample_inputs = (
        None,  # input_ids.
        None,  # past_key_values.
        attention_mask,
        None,  # token_type_ids.
        position_ids,
        None,  # head_mask
        hidden_states,  # inputs_embeds.
        False,  # use_cache.
        False,  # output_attentions.
        False,  # output_hidden_states.
        False,  # return_dict.
    )
    # Export with torch.export and print the FX graph to inspect the input placeholders.
    model_exported = torch.export.export(model, sample_inputs)
    model_exported.graph.print_tabular()

    # Export to StableHLO
    options = torch_xla.stablehlo.StableHLOExportOptions(
        include_human_readable_text=True,
        inline_all_constant=True,
        export_weights=True
    )
    stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(
        model_exported,
        options
    )

    print(stablehlo_gm.get_stablehlo_text())

Environment

Nullkooland commented 3 months ago

I checked the output of model_exported.graph.print_tabular():

opcode         name                        target                           args                                                        kwargs
-------------  --------------------------  -------------------------------  ----------------------------------------------------------  -------------------------------------------------------------------------------
placeholder    p_h_0_ln_1_weight           p_h_0_ln_1_weight                ()                                                          {}
placeholder    p_h_0_ln_1_bias             p_h_0_ln_1_bias                  ()                                                          {}
placeholder    p_h_0_attn_q_proj_weight    p_h_0_attn_q_proj_weight         ()                                                          {}
placeholder    p_h_0_attn_k_proj_weight    p_h_0_attn_k_proj_weight         ()                                                          {}
placeholder    p_h_0_attn_v_proj_weight    p_h_0_attn_v_proj_weight         ()                                                          {}
placeholder    p_h_0_attn_out_proj_weight  p_h_0_attn_out_proj_weight       ()                                                          {}
placeholder    p_h_0_mlp_fc_in_weight      p_h_0_mlp_fc_in_weight           ()                                                          {}
placeholder    p_h_0_mlp_fc_in_bias        p_h_0_mlp_fc_in_bias             ()                                                          {}
placeholder    p_h_0_mlp_fc_out_weight     p_h_0_mlp_fc_out_weight          ()                                                          {}
placeholder    p_h_0_mlp_fc_out_bias       p_h_0_mlp_fc_out_bias            ()                                                          {}
placeholder    p_ln_f_weight               p_ln_f_weight                    ()                                                          {}
placeholder    p_ln_f_bias                 p_ln_f_bias                      ()                                                          {}
placeholder    c_h_0_attn_embed_positions  c_h_0_attn_embed_positions       ()                                                          {}
placeholder    b_h_0_attn_bias             b_h_0_attn_bias                  ()                                                          {}
placeholder    c_h_0_attn_scale_attn       c_h_0_attn_scale_attn            ()                                                          {}
placeholder    input_ids                   input_ids                        ()                                                          {}
placeholder    past_key_values             past_key_values                  ()                                                          {}
placeholder    c_h_0_attn_lifted_tensor_0  c_h_0_attn_lifted_tensor_0       ()                                                          {}
placeholder    attention_mask              attention_mask                   ()                                                          {}
placeholder    token_type_ids              token_type_ids                   ()                                                          {}
placeholder    position_ids                position_ids                     ()                                                          {}
placeholder    head_mask                   head_mask                        ()                                                          {}
placeholder    inputs_embeds               inputs_embeds                    ()                                                          {}
placeholder    use_cache                   use_cache                        ()                                                          {}
placeholder    output_attentions           output_attentions                ()                                                          {}
placeholder    output_hidden_states        output_hidden_states             ()                                                          {}
placeholder    return_dict                 return_dict                      ()                                                          {}
...

and inspected the _flat_input_args passed to XLAExportInterpreter:

[screenshot of _flat_input_args]

It looks like the lifted constant tensor arg c_h_0_attn_lifted_tensor_0 sits at the wrong position in the graph input signature, so it does not line up with its expected position in _flat_input_args.
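
A rough sketch for making the mismatch visible (my own diagnostic against the repro's model_exported, not code from torch_xla): print each placeholder in graph order together with its kind from the export graph signature, so the position of the lifted constant relative to the user inputs is easy to see.

kind_by_name = {}
for spec in model_exported.graph_signature.input_specs:
    # Non-tensor user inputs (e.g. the bool flags) may carry a different
    # argument class, so guard the .name access.
    name = getattr(spec.arg, "name", None)
    if name is not None:
        kind_by_name[name] = spec.kind.name

# Walk the placeholders in graph order; per the print_tabular output above,
# c_h_0_attn_lifted_tensor_0 appears among the user inputs rather than next
# to the other lifted constants.
for i, node in enumerate(model_exported.graph.nodes):
    if node.op == "placeholder":
        print(i, node.name, kind_by_name.get(node.name))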