dellis23 opened this issue 1 year ago
For another possible path (if you just want some GPT-2 and care less about the specifics of the workflow), @rsuderman set up https://github.com/iree-org/iree-jax/tree/main/models/gpt2, which is tested continuously: https://github.com/iree-org/iree-jax/actions/workflows/test_gpt2_model.yaml
Thanks. We do care about GPT-2 in this case (we picked it for its popularity), but it's more about exercising the various workflows, making them as smooth as possible, and identifying any bugs that arise. Right now I'm going through the TF lite path, since my understanding is that one thing users might want to do is test their existing TF lite models on IREE.
@jpienaar
TL;DR: This tflite model doesn't convert with experimental_tflite_to_tosa_bytecode and doesn't throw an error; there's no output or written file.
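For context, the conversion call is roughly the following sketch (assuming the tf.mlir.experimental Python wrapper around ExperimentalTFLiteToTosaBytecode; file names are placeholders and the exact arguments may differ):

```python
import tensorflow as tf

# Sketch of the conversion entry point; paths are placeholders.
# This wraps the C++ ExperimentalTFLiteToTosaBytecode path and is expected
# to write TOSA MLIR bytecode for the given flatbuffer.
tf.mlir.experimental.tflite_to_tosa_bytecode(
    "gpt2-head-model.tflite",   # input TFLite flatbuffer
    "gpt2-head-model.mlirbc",   # expected output, which never gets written here
)
```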
After experimenting with the converted model, I quickly realized we weren't getting back predicted tokens, but rather an embedding for the passed-in text values. To get token predictions, I need to use TFGPT2LMHeadModel. To generate a tflite model from it, I did the following:
```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
import tensorflow as tf

# Instantiate the model and tokenizer
model = TFGPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Sanity-check generation before converting
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors='tf')
output = model.generate(input_ids=input_ids)
print(tokenizer.decode(output[0]))

# Convert to TF Lite and write the flatbuffer to disk
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

export_file = "gpt2-head-model.tflite"
with open(export_file, "wb") as f:
    f.write(tflite_model)
```
I've uploaded the converted model here. I see a few places in ExperimentalTFLiteToTosaBytecode where we bail with an error. Not sure if one of those is being hit and the error not being bubbled up or what.
@jpienaar
The converter appears to be failing here. Since no error message is being bubbled up (and it's not erroring out at all, actually), I tried printing the TF status message here; it's blank.
Debugging the pass manager shows two failing passes: TosaLegalizeTFLPass and mlir::detail::OpToOpPassAdaptor. The latter appears to be a generic piece of LLVM MLIR infrastructure, so I don't think it's the main issue.
Within TosaLegalizeTFLPass, the failures are happening in ConvertTFLGatherOp; there are a little over a dozen of them. That pattern calls out to convertGatherOp. Interestingly, the failures here all seem to contain an error message, so we should probably figure out why those aren't being bubbled up (and why TF doesn't seem to think there's an error to throw at all).
I will keep digging into why this is failing, but if you have any initial insights or gut instincts that might help, let me know.
I'm guessing it's a missing installation of the error context (StatusScopedDiagnosticHandler is commonly used). Note: if the error is reported via "notifyMatchFailure", those messages aren't shown except when running with --debug (and not in opt mode), but the error should still be propagated. So there could be a missing pass failure.
It looks like the failure is happening here, because of dynamic shapes. Specifically, these shapes are pushed during this indices loop.
Given that we're explicitly bailing here, I'm guessing this model is not compatible? From a user's point of view, is there something that could be done to fix the model to make it compatible? It's also worth noting that even if we bubbled up the error message here, it probably wouldn't be meaningful enough to the average user anyway ("multiply dynamic shapes when reshaping tosa.gather result up"); perhaps they aren't the target audience, though.
Yes, exactly: the error messages in pattern failures are more useful for devs than for users.
TOSA has limited dynamic shape support, but we have a verify-converted pass that should be returning at least a failure.
The user here could fix this by setting the input shapes AFAIK. That would then get propagated during conversion.
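A minimal sketch of what that could look like for the model above, assuming a fixed sequence length of 64 and the concrete-function conversion path (names and the choice of SEQ_LEN are illustrative, not from this issue):

```python
import tensorflow as tf
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("gpt2")
SEQ_LEN = 64  # assumed fixed sequence length

# Pin a fully static input signature so the converter never sees dynamic dims.
@tf.function(input_signature=[tf.TensorSpec([1, SEQ_LEN], tf.int32, name="input_ids")])
def serving(input_ids):
    outputs = model(input_ids)
    return {"logits": outputs.logits}

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [serving.get_concrete_function()], model)
tflite_model = converter.convert()

with open("gpt2-static-64.tflite", "wb") as f:
    f.write(tflite_model)
```

In principle, the static shapes should then propagate through the TOSA legalization instead of hitting the dynamic-shape bailout in convertGatherOp.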
I am attempting to run the TF Lite GPT-2 model on IREE (https://huggingface.co/gpt2). I've tried a couple of approaches and hit dead ends on both.
Approach 1 – load the provided 64.tflite model directly

Steps taken:

- iree-import-tflite to convert 64.tflite to mlir bytecode
- iree-run-module
I hit a snag on step 3, since I didn't know the name of the function to run. I first tried using iree-dump-module to inspect the available functions, but it gave no output.

Building and running the tool directly caused it to segfault:
I then tried importing the Keras version of the model (tf_model.h5) in order to inspect it and find a list of the exposed functions, but it failed:

I also tried inspecting the available signatures in TF lite directly:
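That kind of check looks roughly like the following sketch (using the tf.lite.Interpreter API; the exact snippet and its output aren't reproduced above, and the model path is a placeholder):

```python
import tensorflow as tf

# Inspect what the flatbuffer actually exposes.
interpreter = tf.lite.Interpreter(model_path="64.tflite")
print(interpreter.get_signature_list())   # empty dict if no signatures are exported
print(interpreter.get_input_details())    # raw input tensor metadata
print(interpreter.get_output_details())   # raw output tensor metadata
```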
At this point, I realized the model is probably just serialized weights that still require the original model to run, so I tried pivoting to a different method.
Approach 2 – convert the full model to tflite and then run on IREE
I found a notebook where someone converted the full model to TF lite, so I decided to try a similar approach.

- iree-import-tflite to convert the new tflite model to mlir bytecode
- iree-run-module
After converting to TF lite (step 1), I could now see an exposed function signature:
However, when attempting to compile to MLIR (step 3), I got the following error:
Discussion
What's the recommended flow here? I'm not sure there's actually a bug, but if there isn't, I'm also not sure what I should be doing instead. Are there tooling or error-message improvements we could make so it's clear to users how to get around this?
It's interesting to me that compilation worked on the mlir bytecode that is probably just weights (though I haven't been able to confirm this), but failed, for lack of a vm.module op, on the one that should have actually contained an entry function.