google-coral / tflite

Examples using TensorFlow Lite API to run inference on Coral devices
https://coral.withgoogle.com
Apache License 2.0

Encoding/decoding for an NLP model in TensorFlow Lite (fine-tuned GPT-2) #47

Closed. Guillaume-slize closed this issue 3 years ago.

Guillaume-slize commented 3 years ago

We are building a small virtual assistant and would like it to run a fine-tuned version of GPT-2 on a Raspberry Pi with a Coral accelerator.

So far, we have converted our model to TFLite and obtained first results. We know how to convert words to token indices with the original tokenizer, but the interpreter expects a larger, fixed-shape input tensor. What we are missing is the conversion from indices to that tensor. Is there a simple way to do this?

Our pseudo-code is below; we are stuck at steps 2 and 6:

import numpy as np
import tensorflow as tf
from transformers import GPT2Tokenizer

#Prelude
TF_MODEL_PATH_LITE = "/path/model.tflite"

interpreter = tf.lite.Interpreter(model_path=TF_MODEL_PATH_LITE)
interpreter.allocate_tensors()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']  # fixed shape the model expects, e.g. [1, seq_len]

#1- Encode the input text into token indices
context_idx = tokenizer.encode("Hello world.", return_tensors="tf")

#2- How do we convert context_idx into an np.array of shape input_shape?
input_data = np.zeros(input_shape, dtype=np.int32)  # dummy input for now

#3- Feed the input
interpreter.set_tensor(input_details[0]['index'], input_data)

#4- Run the model
interpreter.invoke()

#5- Get the output as an np.array
output_data = interpreter.get_tensor(output_details[0]['index'])

#6- How do we decode this np.array into token indices?
output_idx = np.random.randint(100)  # dummy for now...

#7- Decode the output indices back into words
string_tf = tokenizer.decode(output_idx, skip_special_tokens=True)
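Steps 2 and 6 could be handled along the lines of the sketch below. It is not an answer given in this thread, and it assumes the converted model takes a single integer input of shape [1, seq_len] and returns logits of shape [1, seq_len, vocab_size]; check `input_details` and `output_details` for your own conversion. The helper names (`ids_to_input_array`, `logits_to_next_id`, `greedy_generate`) and the `run_model` callable are hypothetical. Because GPT-2 uses causal attention, padding on the right should not change the logits at the real token positions, so the next token is read from the position of the last real token. GPT-2 has no dedicated pad token; 0 is used here, and `tokenizer.eos_token_id` is another common choice.

```python
import numpy as np


def ids_to_input_array(token_ids, input_shape, dtype=np.int32, pad_id=0):
    """Step 2: pack a plain list of token ids into a [1, seq_len] array.

    Right-pads with pad_id, or keeps only the last seq_len ids if the
    prompt is too long. Returns (array, number_of_real_tokens).
    """
    batch, seq_len = int(input_shape[0]), int(input_shape[1])
    if batch != 1:
        raise ValueError("this sketch assumes a batch size of 1")
    ids = list(token_ids)[-seq_len:]  # truncate from the left if needed
    if not ids:
        raise ValueError("need at least one token id")
    arr = np.full((1, seq_len), pad_id, dtype=dtype)
    arr[0, :len(ids)] = ids
    return arr, len(ids)


def logits_to_next_id(logits, num_tokens):
    """Step 6: greedy (argmax) choice of the next token id.

    Expects logits of shape [1, seq_len, vocab_size] and reads the row
    belonging to the last real (non-padding) token.
    """
    logits = np.asarray(logits)
    if logits.ndim != 3:
        raise ValueError("expected logits of shape [1, seq_len, vocab_size]")
    return int(np.argmax(logits[0, num_tokens - 1]))


def greedy_generate(prompt_ids, input_shape, run_model, max_new_tokens=20,
                    eos_id=None, pad_id=0, dtype=np.int32):
    """Repeat steps 2 to 6 to produce up to max_new_tokens new ids.

    run_model is a hypothetical callable: it takes the input array and
    returns the logits array (e.g. by wrapping set_tensor/invoke/get_tensor).
    Generation stops early if eos_id is produced.
    """
    ids = list(prompt_ids)
    new_ids = []
    for _ in range(max_new_tokens):
        input_data, n = ids_to_input_array(ids, input_shape, dtype, pad_id)
        next_id = logits_to_next_id(run_model(input_data), n)
        new_ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
        ids.append(next_id)
    return new_ids
```

To connect this to the pseudo-code, encode without `return_tensors` so that `tokenizer.encode("Hello world.")` returns a plain list, pass `input_details[0]['shape']` and `input_details[0]['dtype']`, let `run_model` call `interpreter.set_tensor`, `interpreter.invoke` and `interpreter.get_tensor`, and pass the returned ids to `tokenizer.decode(..., skip_special_tokens=True)`. Argmax gives greedy decoding; sampling strategies such as top-k would replace `logits_to_next_id`.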
Naveen-Dodda commented 3 years ago

@Guillaume-slize We are not sure how to help with that. This repository is specific to Google Coral. Could you redirect your question to tflite-support?

Guillaume-slize commented 3 years ago

Moved this here: https://github.com/tensorflow/tflite-support/issues/549#issue-912736942