Ki6an / fastT5

⚡ boost inference speed of T5 models by 5x & reduce the model size by 3x.
Apache License 2.0

flan-t5 support #68

Open loretoparisi opened 1 year ago

loretoparisi commented 1 year ago

I have converted google/flan-t5-small using the fastT5.export_and_get_onnx_model method with quantization enabled by default:

import sys, os, shutil
from transformers import AutoTokenizer

def t5_to_onnx(model_id, output_dir, quantized):
    # export the model (encoder, init-decoder, decoder) to ONNX via fastT5, optionally quantized
    import fastT5
    model = fastT5.export_and_get_onnx_model(model_id, custom_output_path=output_dir, quantized=quantized)
    return model

def onnx_generate(prompt, onnx_model, tokenizer):
    # tokenize the prompt and run beam-search generation on the exported ONNX model
    tokens = tokenizer(prompt, return_tensors='pt')
    output_ids = onnx_model.generate(input_ids=tokens['input_ids'],
                                     attention_mask=tokens['attention_mask'],
                                     num_beams=2)
    output = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
    return output

if __name__ == '__main__':
    args = sys.argv
    model_id = "google/flan-t5-small" # t5-small | t5-large | google/flan-t5-small
    output_dir = "./models"
    quantized = True
    test_input = "translate English to French: The universe is a dark forest."

    if len(args) > 1:
        model_id = args[1]
    if len(args) > 2:
        output_dir = args[2]
    if len(args) > 3:
        quantized = args[3].lower() in ("true", "1", "yes")
    if len(args) > 4:
        test_input = args[4]
    model_name = model_id.split("/")[-1]

    print(f"model_name: {model_name}")
    print(f"  model_id: {model_id}")
    print(f"output_dir: {output_dir}")
    print(f" quantized: {quantized}")

    os.makedirs(output_dir, exist_ok=True)

    build_dir = os.path.abspath("buildmodel")
    os.makedirs(build_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.save_pretrained(build_dir)
    shutil.copyfile(os.path.join(build_dir, "tokenizer.json"), os.path.join(output_dir, f"{model_name}-tokenizer.json"))

    onnx_model = t5_to_onnx(model_id, build_dir, quantized)
    msuffix = "-quantized" if quantized else ""
    # copy the three exported ONNX graphs (encoder, init-decoder, decoder) to the output folder
    for session in ["encoder", "init-decoder", "decoder"]:
        shutil.copyfile(os.path.join(build_dir, f"{model_name}-{session}{msuffix}.onnx"), os.path.join(output_dir, f"{model_name}-{session}{msuffix}.onnx"))

    test_output = onnx_generate(test_input, onnx_model, tokenizer)
    print(f"> {test_input}")
    print(f"< {test_output}")

This produces the quantized ONNX models:

-rw-r--r--  1 loretoparisi  staff   55769637 Mar 22 11:44 flan-t5-small-decoder-quantized.onnx
-rw-r--r--  1 loretoparisi  staff   35811788 Mar 22 11:44 flan-t5-small-encoder-quantized.onnx
-rw-r--r--  1 loretoparisi  staff   58984632 Mar 22 11:44 flan-t5-small-init-decoder-quantized.onnx
-rw-r--r--  1 loretoparisi  staff    2422164 Mar 22 11:38 flan-t5-small-tokenizer.json

However, when loading the model with an ONNX Runtime ort.InferenceSession:

const session = await ort.InferenceSession.create(modelBuffer, { executionProviders: ["wasm"] });

the generated tokens look strange.

Using the same process with t5-small, everything works fine.
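
To help isolate where it goes wrong, here is a minimal check I can run on the Python side (a sketch; it assumes the fastT5 encoder export exposes the input names input_ids / attention_mask and uses the file paths from the listing above), comparing the quantized ONNX encoder against the original transformers encoder on the same input:

import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoTokenizer, T5EncoderModel

model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
reference_encoder = T5EncoderModel.from_pretrained(model_id).eval()

text = "translate English to French: The universe is a dark forest."
enc = tokenizer(text, return_tensors="pt")

# reference hidden states from the original (non-quantized) encoder
with torch.no_grad():
    ref = reference_encoder(input_ids=enc["input_ids"],
                            attention_mask=enc["attention_mask"]).last_hidden_state.numpy()

# quantized ONNX encoder exported by fastT5 (path and input names assumed from the export above)
sess = ort.InferenceSession("./models/flan-t5-small-encoder-quantized.onnx",
                            providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input_ids": enc["input_ids"].numpy(),
                           "attention_mask": enc["attention_mask"].numpy()})[0]

print("max abs diff :", np.abs(ref - onnx_out).max())
print("mean abs diff:", np.abs(ref - onnx_out).mean())

A large difference here would point at the export/quantization step; a small one would point at the wasm integration instead.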

Ki6an commented 1 year ago

Can you please provide reproducible code and the output you are getting?

Ki6an commented 1 year ago

Does it work with the Python ORT, or are you facing the issue only with the JS version of ORT?
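
One way to narrow that down is to load the very same .onnx files with the Python onnxruntime and print the graph inputs and outputs; the JS feed has to match those names and shapes exactly. A minimal sketch (file names taken from your listing above):

import onnxruntime as ort

for part in ["encoder", "init-decoder", "decoder"]:
    path = f"./models/flan-t5-small-{part}-quantized.onnx"
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    print(part)
    print("  inputs :", [(i.name, i.shape, i.type) for i in sess.get_inputs()])
    print("  outputs:", [(o.name, o.shape, o.type) for o in sess.get_outputs()])

If the Python session runs fine with these files, the problem is more likely in how the wasm side feeds the decoder than in the export itself.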

loretoparisi commented 1 year ago

I'm facing this error in transformers-js, which is using the ONNX-converted model here.

loretoparisi commented 1 year ago

Can you please provide reproducible code and the output you are getting?

Yes, I will fork the original repo and apply the changes.

loretoparisi commented 1 year ago

@Ki6an here is my fork where you can try it.

This will install the app and convert flan-t5-small to ONNX:

git clone https://github.com/loretoparisi/transformers-js.git
cd transformers-js/
pip install transformers
python tools/convert_model.py 

You will then find the quantized models in the /models folder.

To run the demo:

make demo
make run

The demo now points to http://localhost:8152/?model_id=google/flan-t5-small

The tokenizer code is located here.