SaneBow / tflite-kws

Keyword Spotting (KWS) API wrapper for TFLite streaming models.
MIT License

What are the right parameters to run KWS with ds_tc_resnet? #1

Open srewai opened 2 years ago

srewai commented 2 years ago

Hi,

Thanks so much for the wrapper API. I'm trying to do real-time keyword spotting using your wrapper and a .tflite model I trained following the Google paper. The flags used for training the model are listed below.

As suggested in your README, I ran the model with

python3 mic_streaming.py -m <model path>

but I get this error:

2021-09-13 18:41:05.251578: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-09-13 18:41:05.251597: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
################################################################################
Ctrl-C to exit
################################################################################
From cffi callback <function _StreamBase.__init__.<locals>.callback_ptr at 0x7f76d648e048>:
Traceback (most recent call last):
  File "/lhome/srauniy/test2/sonu-venv/lib/python3.6/site-packages/sounddevice.py", line 840, in callback_ptr
    return _wrap_callback(callback, data, frames, time, status)
  File "/lhome/srauniy/test2/sonu-venv/lib/python3.6/site-packages/sounddevice.py", line 2677, in _wrap_callback
    callback(*args)
  File "mic_streaming.py", line 116, in callback
    kw = gkws.process(indata)
  File "/lhome/srauniy/test2/tflite-kws/kws.py", line 110, in process
    self.interpreter.set_tensor(self.input_details[0]['index'], indata)
  File "/lhome/srauniy/test2/sonu-venv/lib/python3.6/site-packages/tensorflow/lite/python/interpreter.py", line 640, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 320 but expected 160 for dimension 1 of input 0.

Model training parameters:

{"data_url": "", "data_dir": "/lhome/my_dataset", "lr_schedule": "linear", "optimizer": "adam", "background_volume": 0.1, "l2_weight_decay": 0.0, "background_frequency": 0.8, "split_data": 1, "silence_percentage": 10.0, "unknown_percentage": 10.0, "time_shift_ms": 100.0, "sp_time_shift_ms": 0.0, "testing_percentage": 10, "validation_percentage": 10, "how_many_training_steps": "20000,20000,20000,20000,20000,20000", "eval_step_interval": 662, "learning_rate": "0.01,0.005,0.002,0.001,0.0005,0.0002", "batch_size": 128, "wanted_words": "heymercedes", "train_dir": "/lhome/srauniy/test2/models_data_v2_35_labels/ds_tc_resnet/", "save_step_interval": 100, "start_checkpoint": "", "verbosity": 0, "optimizer_epsilon": 1e-08, "resample": 0.15, "sp_resample": 0.0, "volume_resample": 0.0, "train": 1, "sample_rate": 16000, "clip_duration_ms": 1000, "window_size_ms": 30.0, "window_stride_ms": 10.0, "preprocess": "raw", "feature_type": "mfcc_tf", "preemph": 0.0, "window_type": "hann", "mel_lower_edge_hertz": 20.0, "mel_upper_edge_hertz": 7600.0, "log_epsilon": 1e-12, "dct_num_features": 40, "use_tf_fft": 0, "mel_non_zero_only": 1, "fft_magnitude_squared": false, "mel_num_bins": 80, "use_spec_augment": 1, "time_masks_number": 2, "time_mask_max_size": 25, "frequency_masks_number": 2, "frequency_mask_max_size": 7, "use_spec_cutout": 0, "spec_cutout_masks_number": 3, "spec_cutout_time_mask_size": 10, "spec_cutout_frequency_mask_size": 5, "return_softmax": 0, "novograd_beta_1": 0.95, "novograd_beta_2": 0.5, "novograd_weight_decay": 0.001, "novograd_grad_averaging": 0, "pick_deterministically": 1, "causal_data_frame_padding": 0, "wav": 1, "quantize": 0, "model_name": "ds_tc_resnet", "activation": "relu", "dropout": 0.0, "ds_filters": "128, 64, 64, 64, 128, 128", "ds_repeat": "1, 1, 1, 1, 1, 1", "ds_filter_separable": "1, 1, 1, 1, 1, 1", "ds_residual": "0, 1, 1, 1, 0, 0", "ds_padding": "'same', 'same', 'same', 'same', 'same', 'same'", "ds_kernel_size": "11, 13, 15, 17, 29, 1", "ds_stride": "1, 1, 1, 1, 1, 1", "ds_dilation": "1, 1, 1, 1, 2, 1", "ds_pool": "1, 1, 1, 1, 1, 1", "ds_max_pool": 0, "ds_scale": 1, "label_count": 3, "desired_samples": 16000, "window_size_samples": 480, "window_stride_samples": 160, "spectrogram_length": 98, "data_stride": 1, "data_frame_padding": null, "summaries_dir": "/models_data_v2_35_labels/ds_tc_resnet/logs/", "training": true}

Could you please suggest what changes I need to make to run real-time inference? The model was trained with one keyword, so the output mapping is:

index_to_label = {1: 'unknown', 2: 'srewai', 0: 'silence'}
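
If it helps, my guess is that the wrapper should then be constructed with the labels ordered by output index, something like this (based on the README example; I'm assuming SILENCE and NOT_KW are exported by kws.py):

from kws import TFLiteKWS, SILENCE, NOT_KW  # assumed import path

# label list ordered by model output index: 0 silence, 1 unknown, 2 keyword
gkws = TFLiteKWS('model.tflite', [SILENCE, NOT_KW, 'srewai'])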

Thanks in advance!

SaneBow commented 2 years ago

You need to make sure the window/block size is consistent with what you used during training. By default the wrapper assumes the training parameters used here: https://github.com/StuartIanNaylor/g-kws/blob/main/crrn.sh

--window_size_ms 40.0 \
--window_stride_ms 20.0 \

But I see you used "window_size_ms": 30.0, "window_stride_ms": 10.0 in your training. It's been some time and I can't recall the exact meaning of these, but you can try setting --block-len-ms 10 or modifying the script to align with your training input.
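
The numbers in your error also line up with this: a streaming model consumes one stride's worth of samples per call, so the expected block size is sample_rate * stride_ms / 1000. A quick sanity check in plain Python, assuming the wrapper's default block length follows the 20 ms stride from crrn.sh:

sample_rate = 16000              # "sample_rate" in your training config

# wrapper default block, assuming it follows the 20 ms stride in crrn.sh
print(sample_rate * 20 // 1000)  # 320 -> the "Got 320" in your traceback

# your model's stride: window_stride_ms = 10.0
print(sample_rate * 10 // 1000)  # 160 -> the "expected 160"

So with --block-len-ms 10 the wrapper should feed 160-sample blocks, which matches what the model expects.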

srewai commented 2 years ago

Thank you so much! I am going to try your suggestion.

srewai commented 2 years ago

Your example calls the method get_next_audio_frame(), but I could not find where it is defined in kws.py. Did I miss something? Please help.

gkws = TFLiteKWS(args.model, [SILENCE, NOT_KW, 'keyword1', 'keyword2'])
while True:
    keyword = gkws.process(get_next_audio_frame())
    if keyword:
        pass  # follow-up actions go here

SaneBow commented 2 years ago

You need to write that function yourself, or you can refer to the implementation in mic_streaming.py.
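
For reference, here is a minimal sketch of such a function using the sounddevice library (the same one mic_streaming.py relies on, per the traceback above). The names and structure are illustrative, not the repo's actual implementation; the 160-sample block matches your model's 10 ms stride at 16 kHz:

import sounddevice as sd

SAMPLE_RATE = 16000  # "sample_rate" from the training config
BLOCK_LEN = 160      # 10 ms stride at 16 kHz ("window_stride_samples")

# Blocking input stream: no callback is passed, so frames are pulled with read()
stream = sd.InputStream(samplerate=SAMPLE_RATE, channels=1,
                        blocksize=BLOCK_LEN, dtype='float32')
stream.start()

def get_next_audio_frame():
    """Read one streaming block (160 samples) from the microphone."""
    frames, overflowed = stream.read(BLOCK_LEN)  # frames has shape (BLOCK_LEN, 1)
    return frames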

srewai commented 2 years ago

Thank you for the quick response!