yoeo / guesslang

Detect the programming language of a source code
https://guesslang.readthedocs.io
MIT License
773 stars 110 forks source link

add tflite support #52

Open fsx950223 opened 2 years ago

fsx950223 commented 2 years ago
import tensorflow as tf
import numpy as np
import json
from guesslang import model
from guesslang import guess

# Load the language table. `j` is the list obtained by iterating the parsed
# JSON (presumably an object, so its keys in file order -- confirm); it is
# indexed further down to turn a predicted class id into a language name.
# The `with` block guarantees the file handle is closed after parsing.
with open('guesslang/data/languages.json', 'r', encoding='utf-8') as fp:
    j = list(json.load(fp))

# Build a Guess instance pointing at ./saved_model and call its export() method.
# NOTE(review): `model` rebinds the name brought in by
# `from guesslang import model` above; the imported module is unreachable
# under that name from here on.
# NOTE(review): export() and the meaning of its second positional argument
# (True) are not visible here -- presumably it writes the SavedModel and the
# .tflite file under ./saved_model; confirm against the PR.
model = guess.Guess('./saved_model')
model.export('guesslang/data/model/variables/variables', True)

# Reload the directory as a TensorFlow SavedModel to call its signatures directly.
saved_model = tf.saved_model.load('./saved_model')

# Sample source snippet (a Python quicksort) used as the single test input.
inputs = ["""
def qsort(items):
    if not items:
        return []
    else:
        pivot = items[0]
        less = [x for x in items if x <  pivot]
        more = [x for x in items[1:] if x >= pivot]
        return qsort(less) + [pivot] + qsort(more)

if __name__ == '__main__':
    items = [1, 4, 2, 7, 9, 3]
    print(f'Sorted: {qsort(items)}')

"""]
# Split every input string into its individual bytes (one ragged row per string).
data = tf.strings.bytes_split(inputs)

# Convert the ragged rows to a dense tensor of fixed shape (1, 10001).
# NOTE(review): 10001 is presumably the input length the model expects -- confirm.
# `inputs` is rebound here from the list of strings to the dense tensor; the
# TFLite section below reads this tensor.
inputs = data.to_tensor(shape=(1, 10001))
content = tf.convert_to_tensor(inputs[0])
# Byte count of the first (only) row, cast to int32 for the signature call.
length = tf.cast(data.row_lengths(1)[0], dtype=tf.int32)
# Run the SavedModel's 'predict' signature on the byte tensor and its length.
predicted = saved_model.signatures['predict'](content=content, length=length)

# Decode the SavedModel prediction and print the name of the best-scoring language.
numpy_floats = predicted['probabilities']
# NOTE(review): `extensions` is assigned but never used in this script.
extensions = predicted['all_classes'][0]
ids = predicted['all_class_ids'][0]
# Position of the highest probability along axis 1.
idx = tf.argmax(numpy_floats, axis=1)
# Map that position to a class id, then to the corresponding entry of `j`.
print(j[ids[idx[0]]])

# Run the same input through the converted TFLite model and print the detected
# language, for comparison with the SavedModel result printed earlier.
interpreter = tf.lite.Interpreter('./saved_model/guesslang.tflite')
input_details = interpreter.get_input_details()
interpreter.allocate_tensors()
# NOTE(review): the positional indices used for the input and output details
# (inputs: 1 = content, 0 = length; outputs: 1 = probabilities, 3 = class ids)
# are kept from the original script and presumably depend on the converter's
# ordering -- confirm against the tensor names in the details dicts.
interpreter.set_tensor(input_details[1]['index'], inputs[0].numpy())
interpreter.set_tensor(input_details[0]['index'], np.array(data.row_lengths(1)[0]).astype(np.int32))

interpreter.invoke()
# Query the output details once. get_tensor() returns a copy of the output, so
# no numpy view into the interpreter's internal buffers is kept alive (a view
# obtained through interpreter.tensor() would block later invoke() calls).
output_details = interpreter.get_output_details()
idx = tf.argmax(interpreter.get_tensor(output_details[1]['index']), axis=1)
ids = interpreter.get_tensor(output_details[3]['index'])[0]
print(j[ids[idx[0]]])

Test script

asiryan commented 1 year ago

Could you share the Guesslang *.tflite model?