iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Unable to get GPT-2 TF lite model compiling and running on IREE #13161

Open dellis23 opened 1 year ago

dellis23 commented 1 year ago

I am attempting to run the TF lite GPT2 model on IREE (https://huggingface.co/gpt2). I've tried a couple of approaches and hit dead ends on both.

Approach 1 – load the provided 64.tflite model directly

Steps taken:

  1. Use iree-import-tflite to convert 64.tflite to mlir bytecode
  2. Compile the mlir bytecode to IREE vmfb
  3. Attempt to run the vmfb via iree-run-module
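
For reference, these steps map onto roughly the following commands (paths are illustrative, and the exact flag names vary a bit across IREE versions):

$ iree-import-tflite 64.tflite -o 64.mlirbc
$ iree-compile 64.mlirbc --iree-input-type=tosa --iree-hal-target-backends=llvm-cpu -o 64.vmfb
$ iree-run-module --module=64.vmfb --function=<entry-name> --input=...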

I hit a snag on step 3, since I didn't know the name of the function to run. I first tried using iree-dump-module to inspect the available functions, but it gave no output.

$ bazel run //tools:iree-dump-module ~/Downloads/gpt2/64.vmfb
INFO: Analyzed target //tools:iree-dump-module (0 packages loaded, 0 targets configured).
INFO: Found 1 target...
Target //tools:iree-dump-module up-to-date:
  bazel-bin/tools/iree-dump-module
INFO: Elapsed time: 0.574s, Critical Path: 0.01s
INFO: 1 process: 1 internal.
INFO: Build completed successfully, 1 total action
INFO: Running command line: bazel-bin/tools/iree-dump-module /usr/local/google/home/danielelli
INFO: Build completed successfully, 1 total action

Building and running the tool directly caused it to segfault:

$ ./bazel-bin/tools/iree-dump-module /usr/local/google/home/danielellis/Downloads/gpt2/64.vmfb
Segmentation fault

I then tried loading the Keras version of the model (tf_model.h5) so I could inspect its exposed functions, but it failed:

ValueError: No model config found in the file at <tensorflow.python.platform.gfile.GFile object at 0x7fc2b3c09cc0>.
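
For reference, the load attempt was along these lines (a reconstruction rather than the exact script):

import tensorflow as tf

# Hypothetical reconstruction: load the Keras .h5 file as a full model.
# This raises the "No model config found" ValueError shown above.
model = tf.keras.models.load_model("tf_model.h5")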

I also tried inspecting the available signatures in TF lite directly:

In [9]: interpreter.get_signature_runner()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[9], line 1
----> 1 interpreter.get_signature_runner()

File ~/venv.gpt2/lib/python3.10/site-packages/tensorflow/lite/python/interpreter.py:829, in Interpreter.get_signature_runner(self, signature_key)
    827 if signature_key is None:
    828   if len(self._signature_defs) != 1:
--> 829     raise ValueError(
    830         'SignatureDef signature_key is None and model has {0} Signatures. '
    831         'None is only allowed when the model has 1 SignatureDef'.format(
    832             len(self._signature_defs)))
    833   else:
    834     signature_key = next(iter(self._signature_defs))

ValueError: SignatureDef signature_key is None and model has 0 Signatures. None is only allowed when the model has 1 SignatureDef

At this point, I realized the model is probably just serialized weights that still require the original model to run, so I tried pivoting to a different method.

Approach 2 – convert the full model to tflite and then run on IREE

I found a notebook where someone converted the full model to TF lite, so I decided to try a similar approach.

  1. Convert the full model to tflite via the same method as the aforementioned notebook
  2. Use iree-import-tflite to convert the new tflite model to mlir bytecode
  3. Compile the mlir bytecode to IREE vmfb
  4. Attempt to run the vmfb via iree-run-module
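
In code, steps 2 and 3 look roughly like the following (a sketch; file names are illustrative, and the compile call mirrors the ireec.compile_str usage visible in the traceback further down):

import iree.compiler as ireec

# Step 2 produced an MLIR bytecode file via iree-import-tflite; read it back in.
with open("gpt2-converted.mlirbc", "rb") as f:
    mlir_bytecode = f.read()

# Step 3: compile the imported TOSA module to an IREE vmfb for the CPU backend.
vmfb = ireec.compile_str(mlir_bytecode,
                         input_type="tosa",
                         target_backends=["llvm-cpu"])
with open("gpt2-converted.vmfb", "wb") as f:
    f.write(vmfb)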

After converting to TF lite (step 1), I could now see an exposed function signature:

In [24]: int2 =  tf.lite.Interpreter('gpt2-converted.tflite')

In [25]: int2.get_signature_runner()
Out[25]: <tensorflow.lite.python.interpreter.SignatureRunner at 0x7fc274d64220>

In [28]: int2.get_signature_list()
Out[28]:
{'serving_default': {'inputs': ['attention_mask', 'input_ids'],
  'outputs': ['logits',
   'past_key_values_1',
   'past_key_values_10',
   'past_key_values_11',
   'past_key_values_12',
   'past_key_values_2',
   'past_key_values_3',
   'past_key_values_4',
   'past_key_values_5',
   'past_key_values_6',
   'past_key_values_7',
   'past_key_values_8',
   'past_key_values_9']}}
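
For completeness, that signature can be exercised directly in TF lite along these lines (the shapes and dtypes here are assumptions; the real ones depend on how the model was converted):

import numpy as np

runner = int2.get_signature_runner('serving_default')
ids = np.zeros((1, 64), dtype=np.int32)  # assumed [batch, sequence] int32 inputs
outputs = runner(input_ids=ids, attention_mask=np.ones_like(ids))
print(outputs['logits'].shape)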

However, when compiling the MLIR bytecode to a vmfb (step 3), I got the following error:

Traceback (most recent call last):
  File "/usr/local/google/home/danielellis/Downloads/gpt2/main.py", line 25, in <module>
    main()
  File "/usr/local/google/home/danielellis/Downloads/gpt2/main.py", line 10, in main
    vmfb = ireec.compile_str(mlir_bytecode,
  File "/usr/local/google/home/danielellis/venv.gpt2/lib/python3.10/site-packages/iree/compiler/tools/core.py", line 278, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/usr/local/google/home/danielellis/venv.gpt2/lib/python3.10/site-packages/iree/compiler/tools/binaries.py", line 196, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Diagnostics:
<stdin>:0:0: error: outer module does not contain a vm.module op
error opening input file: failed to generate bytecode

Invoked with:
 iree-compile /usr/local/google/home/danielellis/venv.gpt2/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=tosa --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvm-embedded-linker-path=/usr/local/google/home/danielellis/venv.gpt2/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false

Need more information? Set IREE_SAVE_TEMPS=/some/dir in your environment to save all artifacts and reproducers.

Discussion

What's the recommended flow here? I'm not sure there's actually a bug, but if there isn't, I also don't know what I should be doing instead. Are there tooling or error-message improvements we could make so it's clear to users how to get around this?

It's interesting to me that compilation worked on the MLIR bytecode that is probably just weights (though I haven't been able to confirm this), but failed for lack of a vm.module op on the one that should have actually contained an entry function.

ScottTodd commented 1 year ago

For another possible path (if you just want some GPT-2 and care less about the specifics of the workflow), @rsuderman set up https://github.com/iree-org/iree-jax/tree/main/models/gpt2, which is tested continuously: https://github.com/iree-org/iree-jax/actions/workflows/test_gpt2_model.yaml

dellis23 commented 1 year ago

Thanks. We do care about GPT-2 in this case (we picked it for its popularity), but the goal is more about exercising the various workflows, making them as smooth as possible, and identifying any bugs that arise. Right now I'm going through the TF lite path, since my understanding is that one thing users will want to do is try their existing TF lite models on IREE.

dellis23 commented 1 year ago

@jpienaar

TL;DR: This tflite model fails to convert with experimental_tflite_to_tosa_bytecode, but no error is thrown; there's no output and no file is written.

After experimenting with the converted model, I quickly realized we weren't getting back predicted tokens, but rather an embedding for the passed-in text. To get token predictions, I need to use TFGPT2LMHeadModel. To generate a tflite model from it, I did the following:

from transformers import TFGPT2LMHeadModel
from transformers import GPT2Tokenizer
import tensorflow as tf

# Instantiate the model
model = TFGPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = "Replace me by any text you'd like."
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors='tf')
output = model.generate(input_ids=input_ids)
print(tokenizer.decode(output[0]))

# Convert to TF Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
export_file = "gpt2-head-model.tflite"
open(export_file, "wb").write(tflite_model)

I've uploaded the converted model here. I see a few places in ExperimentalTFLiteToTosaBytecode where we bail with an error. I'm not sure whether one of those is being hit and the error just isn't being bubbled up, or something else is going on.
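
For reference, the conversion that silently produces nothing was invoked roughly as follows (a sketch; the exact Python entry point and signature depend on the TF build, assumed here to be tf.mlir.experimental.tflite_to_tosa_bytecode):

import tensorflow as tf

# Assumed entry point for the experimental TFLite-to-TOSA-bytecode converter.
# With this model it returns without raising, but no output file is written.
tf.mlir.experimental.tflite_to_tosa_bytecode(
    "gpt2-head-model.tflite",   # flatbuffer produced by the script above
    "gpt2-head-model.mlirbc",   # expected output path; nothing gets written
)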

dellis23 commented 1 year ago

@jpienaar

The converter appears to be failing here. Since no error message is being bubbled up (and it's not erroring out at all, actually), I tried printing the TF status message here; it's blank.

Debugging the pass manager shows two failing passes: TosaLegalizeTFLPass and mlir::detail::OpToOpPassAdaptor. The latter is a generic piece of LLVM/MLIR pass infrastructure, so I don't think it's the main issue.

The TosaLegalizeTFLPass failure is happening in ConvertTFLGatherOp; there are a little over a dozen of these failures, and each calls out to convertGatherOp. Interestingly, the failures here all seem to carry an error message, so we should probably figure out why those are not being bubbled up (and why TF doesn't seem to think there's an error to throw at all).

I will keep digging into why this is failing, but if you have any initial insights or gut instincts that might help, let me know.

jpienaar commented 1 year ago

I'm guessing the error context isn't being installed (StatusScopedDiagnosticHandler is commonly used for this). Note: if the error is reported via "notifyMatchFailure", those aren't shown except when running with --debug (and not in opt mode), but the failure should still be propagated. So a pass failure may be getting dropped.

dellis23 commented 1 year ago

It looks like the failure is happening here, because of dynamic shapes. Specifically, these shapes are pushed during this indices loop.

Given that we're explicitly bailing here, I'm guessing this model is simply not compatible? From a user's point of view, is there something that could be done to make the model compatible? It's also worth noting that even if we bubbled up the error message here, it probably wouldn't be meaningful to the average user anyway ("multiply dynamic shapes when reshaping tosa.gather result up"); perhaps they aren't the target audience, though.

jpienaar commented 1 year ago

Yes, exactly; the error messages from pattern failures are more useful for devs than for users.

TOSA has limited dynamic shape support, but we have a verify-converted pass that should at least be returning a failure.

AFAIK the user could fix this by setting the input shapes; those would then get propagated during conversion.
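
One way to pin the input shapes at conversion time (a sketch, assuming a fixed batch size of 1 and sequence length of 64) is to convert from a concrete function with static TensorSpecs rather than from the Keras model directly:

import tensorflow as tf
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("gpt2")

# Wrap the model so the converter sees fixed [1, 64] int32 inputs instead of
# dynamic shapes.
@tf.function(input_signature=[
    tf.TensorSpec([1, 64], tf.int32, name="input_ids"),
    tf.TensorSpec([1, 64], tf.int32, name="attention_mask"),
])
def serving(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [serving.get_concrete_function()], model)
tflite_model = converter.convert()
open("gpt2-head-model-static.tflite", "wb").write(tflite_model)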