llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Issue with lowering GPT2 to linalgIR #3649

Closed: sdalvi-quic closed this issue 5 days ago

sdalvi-quic commented 3 weeks ago

I am trying to lower the GPT-2 model to Linalg IR, but I am running into errors. I built torch_mlir from source and installed the latest version of transformers with `pip install git+https://github.com/huggingface/transformers`.

The test case I am running is:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch_mlir import torchscript

test_modelname = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(test_modelname)
prompt = "What is nature of our existence?"
encoding = tokenizer(prompt, return_tensors="pt")

model = GPT2LMHeadModel.from_pretrained(
    test_modelname,
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    torchscript=True,  # return plain tuples instead of ModelOutput objects
    attn_implementation="eager",
)

model.to("cpu")
model.eval()

# Sanity check: the model runs in eager PyTorch
model_response = model.generate(
    encoding["input_ids"].cpu(),
    do_sample=True,
    top_k=50,
    max_length=100,
    top_p=0.95,
    temperature=1.0,
)
print("Prompt:", prompt)
print("Response:", tokenizer.decode(model_response[0]).encode("utf-8"))
print("Input:", encoding["input_ids"].cpu())
print("Output:", model(encoding["input_ids"].cpu()))
print("Tokenizer output:", encoding)

# Lower through the TorchScript importer
out = torchscript.compile(
    model,
    [encoding["input_ids"]],  # [encoding["attention_mask"]],
    output_type="linalg-on-tensors",
    use_tracing=True,  # use torch.jit.trace instead of torch.jit.script
)
```

With tracing enabled (i.e. `torch.jit.trace()`), I get the following error:

```
Traceback (most recent call last):
  File "/local/mnt/workspace/sdalvi/ai-tools/triton_test/benchmarks/triton/LLM/gpt2/gpt.py", line 48, in <module>
    out = torchscript.compile(
          ^^^^^^^^^^^^^^^^^^^^
  File "/local/mnt/workspace/sdalvi/ai-tools/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/torchscript.py", line 401, in compile
    run_pipeline_with_repro_report(
  File "/local/mnt/workspace/sdalvi/ai-tools/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 78, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: unsupported by backend contract: non-value tensor type
note: see current operation: %205 = "torch.copy.to_tensor"(%204) : (!torch.vtensor<[7,2304],f32>) -> !torch.tensor<[7,2304],f32>
note: this is likely due to a missing case in the MaximizeValueSemantics pass

python exception: Failure while executing pass pipeline
```

Running `torch-mlir-opt` reports:

```
/local/mnt/workspace/sdalvi/miniconda3/envs/torch_mlir/lib/python3.11/site-packages/transformers/pytorch_utils.py:107:0: error: unsupported by backend contract: non-value tensor type
/local/mnt/workspace/sdalvi/miniconda3/envs/torch_mlir/lib/python3.11/site-packages/transformers/pytorch_utils.py:107:0: note: see current operation: %66 = "torch.copy.to_tensor"(%65) : (!torch.vtensor<[7,2304],f32>) -> !torch.tensor<[7,2304],f32>
/local/mnt/workspace/sdalvi/miniconda3/envs/torch_mlir/lib/python3.11/site-packages/transformers/pytorch_utils.py:107:0: note: this is likely due to a missing case in the MaximizeValueSemantics pass
```

The location reported by `torch-mlir-opt` (`transformers/pytorch_utils.py:107`) is the line `x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)` in the following code:

```python
import torch
from torch import nn


class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x: torch.Tensor):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)  # pytorch_utils.py:107
        x = x.view(size_out)
        return x
```
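The docstring's claim that `Conv1D` "works like a linear layer but the weights are transposed" can be checked numerically. The minimal sketch below uses illustrative sizes and random weights (not GPT-2's actual dimensions) and compares the `torch.addmm` path from `Conv1D.forward` with `torch.nn.functional.linear` applied to the transposed weight:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
nx, nf = 4, 6                         # illustrative input / output feature counts
x = torch.randn(2, 7, nx)             # (batch, seq_len, nx)
weight = torch.randn(nx, nf) * 0.02   # Conv1D stores its weight as (nx, nf)
bias = torch.zeros(nf)

# Conv1D.forward: flatten leading dims, compute bias + x @ weight, restore shape
size_out = x.size()[:-1] + (nf,)
conv1d_out = torch.addmm(bias, x.view(-1, x.size(-1)), weight).view(size_out)

# nn.Linear convention stores weight as (out_features, in_features), so transpose
linear_out = F.linear(x, weight.t(), bias)

print(conv1d_out.shape, torch.allclose(conv1d_out, linear_out, atol=1e-5))
```

Only the weight layout differs between the two conventions, so the outputs match up to floating-point tolerance.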

With scripting enabled (i.e. running `torch.jit.script`), I get:

```
raise NotSupportedError(r, "Comprehension ifs are not supported yet")
torch.jit.frontend.NotSupportedError: Comprehension ifs are not supported yet:
  File "/local/mnt/workspace/sdalvi/miniconda3/envs/MLIR_p3.11/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1163

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
```

I tried to work around this error, but then ran into another one:

```
Expected a value of type 'Tensor (inferred)' for argument 'x' but instead found type 'Optional[Tuple[Tensor]]'.
Inferred 'x' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/local/mnt/workspace/sdalvi/miniconda3/envs/MLIR_p3.11/lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 312
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
                                ~~~~~~~~~~~ <--- HERE

        query = self._split_heads(query, self.num_heads, self.head_dim)
```

I suspect that the scripting error and the tracing error point to the same underlying issue.

How can this be resolved? Because the error points to `Conv1D` in `transformers/pytorch_utils.py`, other LLMs that use this layer fail with the same error.

sdalvi-quic commented 3 weeks ago

@fhossein-quic, @trahman-quic.

sdalvi-quic commented 3 weeks ago

Hi @ramiro050, @ZihengJiang, I see that you have worked on similar issues before (https://github.com/llvm/torch-mlir/issues/2523, https://github.com/llvm/torch-mlir/issues/1151). I am running into a similar issue with the GPT-2 model. Could you give me some pointers?

ramiro050 commented 2 weeks ago

I see that you're using the old TorchScript importer. Have you tried using the FX importer? It should functionalize your model (remove mutation) before the model gets passed to Torch-MLIR. Here is the interface for the FX importer: https://github.com/llvm/torch-mlir/blob/eb7bf78a9c1e250949cf0151628f35fb0ac06903/python/torch_mlir/fx.py#L51

sdalvi-quic commented 2 weeks ago

Thank you, @ramiro050! I was able to use the FX importer to lower the model to Linalg IR.