linto-ai / linto-desktoptools-hmg

GUI Tool to create, manage and test Keyword Spotting models using TF 2.0
GNU Affero General Public License v3.0

Evaluation result samples don't match false positive / negatives #11

StuartIanNaylor closed this issue 3 years ago

StuartIanNaylor commented 3 years ago

After running the evaluation I should have got 100 non-hotwords detected as hotwords and 1180 hotwords detected as non-hotwords.

There were far more (too many to count) than the 100 expected hotwords in the selection window, and only 7 non-hotwords out of the 1180 false positives.
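For comparison, the counts reported by an evaluation and the samples listed for review would normally come from the same split of indices. The following is only a minimal sketch of that idea, not the tool's actual code: the function name `confusion_samples` and the toy data are hypothetical.

```python
def confusion_samples(labels, predictions):
    """Split sample indices into false positives and false negatives.

    labels / predictions: sequences of booleans, True meaning "hotword".
    """
    pairs = list(enumerate(zip(labels, predictions)))
    # Non-hotword detected as hotword.
    false_positives = [i for i, (y, p) in pairs if p and not y]
    # Hotword detected as non-hotword.
    false_negatives = [i for i, (y, p) in pairs if y and not p]
    return false_positives, false_negatives


# Hypothetical toy data, not taken from the actual evaluation run.
labels = [True, True, False, False, True, False]
predictions = [True, False, True, False, False, False]

fp, fn = confusion_samples(labels, predictions)
print(len(fp), fp)  # 1 [2]
print(len(fn), fn)  # 2 [1, 4]
```

If the review lists are built from the same index lists as the counts, the number of entries in each list has to equal the reported figure, which is the check that fails in the behaviour described above.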

As you go down the list and click play, only the samples in the top section play audio.

Also, in the testing section the button to add a sample to the set is always greyed out; it would be excellent to be able to add to the set. Clicking play does nothing there either.

The tool also doesn't support an unrolled GRU.

I thought TensorFlow Lite supported the full subset now?

https://github.com/tensorflow/tensorflow/issues/34266

StuartIanNaylor commented 3 years ago

Just to mention, there is a GRU running on TFLite in https://github.com/google-research/google-research/blob/master/kws_streaming/experiments/kws_experiments_paper_12_labels.md#gru_state

Google Research export to TFLite and have a benchmark on Android. I must admit the way it trains seems quite different from Keras, but the code is there, and they evidently do run a GRU on TFLite even if I am not sure how.

There is also flex ops, which gets a big update in 2.5; https://github.com/PINTO0309/TensorflowLite-flexdelegate is going to publish as soon as it is released. I am also not sure whether native TFLite GRU support is in the pipeline (https://github.com/tensorflow/tensorflow/issues/46390), but with flex ops the GRU can be imported from TF core.

With flex ops and 2.5, a true hybrid such as a CRNN (a CNN feeding a GRU) becomes interesting, as it will be worth seeing how it performs compared with, say, a full GRU run via flex ops.

StuartIanNaylor commented 3 years ago

I hacked together a rough script from the test.py of the Google KWS code, using the model with external state.

import sounddevice as sd
import numpy as np
import timeit
import tensorflow.compat.v1 as tf

# Parameters
debug_time = 0
debug_acc = 0
word_threshold = 10.0
rec_duration = 0.020
sample_rate = 16000
num_channels = 1

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="models2/crnn_state/tflite_stream_state_external/stream_state_external.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details[0]['shape'])
inputs = []
# One zero-filled buffer per model input (index 0 is audio, 1... are states).
for s in range(len(input_details)):
    inputs.append(np.zeros(input_details[s]['shape'], dtype=np.float32))

def sd_callback(rec, frames, time, status):

    start = timeit.default_timer()

    # Notify if errors
    if status:
        print('Error:', status)

    rec = np.reshape(rec, (1, 320))

    # Make prediction from model
    interpreter.set_tensor(input_details[0]['index'], rec.astype(np.float32))
    # set input states (index 1...)
    for s in range(1, len(input_details)):
        interpreter.set_tensor(input_details[s]['index'], inputs[s])

    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # get output states and set it back to input states
    # which will be fed in the next inference cycle
    for s in range(1, len(input_details)):
        # The function `get_tensor()` returns a copy of the tensor data.
        # Use `tensor()` in order to get a pointer to the tensor.
        inputs[s] = interpreter.get_tensor(output_details[s]['index'])

    out_tflite_argmax = np.argmax(output_data)

    if out_tflite_argmax == 2:
        print('raspberry')
        print(output_data[0][2])

    if debug_acc:
        print(out_tflite_argmax)

    if debug_time:
        print(timeit.default_timer() - start)

# Start streaming from microphone
with sd.InputStream(channels=num_channels,
                    samplerate=sample_rate,
                    blocksize=int(sample_rate * rec_duration),
                    callback=sd_callback):
    while True:
        # Sleep instead of busy-waiting; the callback runs on its own thread.
        sd.sleep(1000)
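
The key part of the script above is the external state: each inference returns updated state tensors that have to be fed back in on the next call. As a minimal sketch of just that pattern, with no TFLite involved, the toy `step` function below is a hypothetical stand-in for one `invoke()` cycle and simply keeps a running average of frame energy.

```python
def step(frame, state):
    """Toy stand-in for one streaming inference call (hypothetical).

    Returns (output, new_state). The "model" here is just an
    exponential moving average of the frame energy.
    """
    energy = sum(x * x for x in frame) / len(frame)
    new_state = 0.5 * state + 0.5 * energy
    return new_state, new_state


state = 0.0  # initial state, like the zero-filled state buffers
outputs = []
for frame in ([0.0, 0.0], [1.0, 1.0], [1.0, 1.0]):
    # Feed the state returned by the previous call back into the next one.
    output, state = step(frame, state)
    outputs.append(output)

print(outputs)  # [0.0, 0.5, 0.75]
```

If the state were reset to zero on every call, each frame would be scored in isolation, which is why the script copies the output states back into `inputs` after every inference.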