tensorflow / tflite-support

TFLite Support is a toolkit that helps users develop ML and deploy TFLite models onto mobile / IoT devices.
Apache License 2.0

No categories present in the resulting list after calling classifier.classify(), even though 20 categories are expected to be there. #943

Open God-Ra opened 1 year ago

God-Ra commented 1 year ago

I'm trying to classify text into one of 20 categories (the 20newsgroups dataset). I have a TFLite model that I use for inference, which is run through the BertNLClassifier.classify() method. The output contains no categories, even though it should contain 20 categories with their respective scores. Also, there are no log messages coming from the Task Library. Is this a bug, or am I doing something wrong?

The first thing I did was convert the trained model from PyTorch to TFLite using the optimum-cli tool. The command is: optimum-cli export tflite --model ./test_model_first_bert --task text-classification --sequence_length 512 --quantize int8 tflite-quant/
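Since the metadata writer needs the model's exact input tensor names (the ids_name, mask_name and segment_name arguments below), one thing worth double-checking is what the exported model actually calls its tensors. A minimal sketch with the TFLite Python interpreter (the model path here is only an assumption based on the export command above):

# Sketch: list the exported model's input/output tensor names so the
# ids_name / mask_name / segment_name arguments passed to the metadata
# writer can be matched against them.
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="tflite-quant/model.tflite")
interpreter.allocate_tensors()

for detail in interpreter.get_input_details():
    print("input :", detail["name"], detail["shape"], detail["dtype"])
for detail in interpreter.get_output_details():
    print("output:", detail["name"], detail["shape"], detail["dtype"])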

After that, I added metadata with the code below:

from tflite_support.metadata_writers import bert_nl_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

NLClassifierWriter = bert_nl_classifier.MetadataWriter
_MODEL_PATH = "model.tflite"
# Task Library expects label files and vocab files that are in the same formats
# as the ones below.
_LABEL_FILE = "bert-20ng-labels.txt"
_VOCAB_FILE = "vocab-my.txt"
# Note: BertNLClassifier tokenizes with the Bert tokenizer configured via
# BertTokenizerMd below; this regex delimiter is left over from the regex-based
# NLClassifier example and is unused here.
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_SAVE_TO_PATH = "model-metadata.tflite"

# Create the metadata writer.
writer = bert_nl_classifier.MetadataWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH),
    metadata_info.BertTokenizerMd(_VOCAB_FILE),
    [_LABEL_FILE],
    ids_name="model_input_ids:0",
    mask_name="model_attention_mask:0",
    segment_name="model_token_type_ids:0")

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

The vocab file is the same as in mobilebert.tflite (from here), and the label file contains 20 lines, each line specifying one category.
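To confirm that the label file, vocab file and tokenizer options actually got packed into model-metadata.tflite, the metadata can be read back out with MetadataDisplayer (a small sketch using the tflite_support metadata API):

# Sketch: dump the metadata and the list of associated files packed into the
# model, to verify that the label and vocab files were attached correctly.
from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("model-metadata.tflite")
print(displayer.get_metadata_json())
# Should list bert-20ng-labels.txt and vocab-my.txt:
print(displayer.get_packed_associated_file_list())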

After incorporating the metadata, I tried to run inference in Android Studio with a dummy text that, in theory, should prompt classifier.classify() to return 20 categories with their respective scores. However, it returned an empty list with no results. Here is the code that I used to classify the text.

package net.etfbl.mrprojektni;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.util.Log;

import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.BertNLClassifier;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.List;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        try {
            nlClassifierTest();
        } catch (Exception e) {
            // Log the stack trace; otherwise failures in createFromFile() or
            // classify() are silently swallowed.
            StringWriter sw = new StringWriter();
            PrintWriter pw = new PrintWriter(sw);
            e.printStackTrace(pw);
            Log.e("tag", sw.toString());
        }
    }

    private void nlClassifierTest() throws IOException {

        String tfliteFileName = "model-metadata.tflite";
        BertNLClassifier textClassifier = BertNLClassifier.createFromFile(this, tfliteFileName);

        final String text = "This is a text about baseball.";
        List<Category> results = textClassifier.classify(text);

        String textToShow = "Input: " + text + "\nOutput:\n";
        for (int i = 0; i < results.size(); i++) {
            Category result = results.get(i);
            textToShow +=
                    String.format("    %s: %s\n", result.getLabel(), result.getScore());
        }
        textToShow += "---------\n";

        showResult(textToShow);
    }
    private void showResult(String textToShow) {
        Log.d("tag", textToShow);
    }
}

The output is: (see attached screenshot)
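To narrow down whether the empty result comes from the model itself or from the Task Library wrapper, the quantized model can also be run directly with the Python interpreter on dummy inputs and the raw output inspected. A rough sketch (the shapes and dtypes are assumptions based on the export settings and should really be taken from get_input_details()):

# Sketch: feed zero-filled dummy inputs and check that the raw output is a
# (1, 20) score tensor. Input order, shapes and dtypes are assumptions; match
# them against get_input_details() for the real model.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model-metadata.tflite")
interpreter.allocate_tensors()

for detail in interpreter.get_input_details():
    dummy = np.zeros(detail["shape"], dtype=detail["dtype"])
    interpreter.set_tensor(detail["index"], dummy)

interpreter.invoke()

output_detail = interpreter.get_output_details()[0]
scores = interpreter.get_tensor(output_detail["index"])
print(scores.shape)  # expected: (1, 20)
print(scores)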

Dependencies of the Android Studio project:

    implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
    implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite:2.13.0'
    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.13.0'
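Another way to isolate the Android side would be to run the same model through the Python Task Library on the desktop; if that also returns no categories, the problem is in the model or metadata rather than in the app. A sketch (the exact attribute names on the result object may differ between tflite-support versions):

# Sketch: desktop cross-check with the Python Task Library.
from tflite_support.task import text

classifier = text.BertNLClassifier.create_from_file("model-metadata.tflite")
result = classifier.classify("This is a text about baseball.")

# Each classification head should carry the 20 categories with their scores.
for category in result.classifications[0].categories:
    print(category.category_name, category.score)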

I couldn't upload the model.tflite here, so I uploaded it to ufile.io (https://ufile.io/gravczdv).