sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Getting error while loading a model #703

Closed: 7r3nzy closed this issue 2 years ago.

7r3nzy commented 2 years ago

I am trying to load a Silero VAD model from here.

Code:

let model = onnx()
    // load the model
    .model_for_path(model_path)?
    // specify input/output names and input types and shapes
    .with_input_names(["input", "h0", "c0"])?
    .with_output_names(["output", "hn", "cn"])?
    .with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 1)))?
    .with_input_fact(1, InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)))?
    .with_input_fact(2, InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)))?
    // optimize the model
    .into_optimized()?
    // make the model runnable and fix its inputs and outputs
    .into_runnable()?;

Error:

Failed analyse for node #121 "Pad_27" Pad

 Caused by:
     0: Infering facts
     1: Applying rule GivenRule { inputs[1] }
     2: Tensor datum type error: tensor is TDim, accessed as I64

I am able to load the model successfully if I skip the into_optimized() step.

On a separate note, with a successfully loaded model I am not getting the same output values as I get with utils_vad.py from Silero. I will share a repro for that in a separate issue as soon as I get the chance.

kali commented 2 years ago

Hey, thanks for the report. I had to make a few fixes, but I got it to work (I think). Use the linked branch if you want to try it; I will merge it ASAP.

Here is an all-in-one Python inference script (it also exports the wav tensor to io.npz for reuse in tract):

import torch
import torchaudio
import onnxruntime
import numpy

session = onnxruntime.InferenceSession("files/silero_vad.onnx")

wav, sr = torchaudio.load("en.wav")
assert sr == 16000

samples = wav.shape[1]
window_size_samples = 1536

h = numpy.zeros((2, 1, 64)).astype('float32')
c = numpy.zeros((2, 1, 64)).astype('float32')

io = dict()
io["wav"] = wav.numpy().astype('float32')
numpy.savez("io.npz", **io)

output = []

for win in range(0, samples, window_size_samples):
    x = wav[:, win:win + window_size_samples]
    if x.shape[1] < window_size_samples:
        # zero-pad the last, shorter window on the right of the time axis
        x = torch.nn.functional.pad(x, (0, window_size_samples - x.shape[1]))
    ort_inputs = {'input': x.numpy(), 'h0': h, 'c0': c}
    # the returned hn/cn become h0/c0 for the next window
    y, h, c = session.run(None, ort_inputs)
    output.append(y[(0, 1, 0)])

min_silence_duration_ms = 100
min_speech_duration_ms = 250
threshold = 0.5
neg_threshold = 0.35

triggered = False
current_speech = 0
temp_end = 0

min_silence_samples = min_silence_duration_ms * 16000 / 1000
min_speech_samples = min_speech_duration_ms * 16000 / 1000

for i, speech_prob in enumerate(output):
    if (speech_prob >= threshold) and temp_end:
        temp_end = 0

    if (speech_prob >= threshold) and not triggered:
        triggered = True
        current_speech = window_size_samples * i
        continue

    if (speech_prob < neg_threshold) and triggered:
        if not temp_end:
            temp_end = window_size_samples * i
        if (window_size_samples * i) - temp_end < min_silence_samples:
            continue
        else:
            if temp_end - current_speech > min_speech_samples:
                print(current_speech / 16000, temp_end / 16000)
            temp_end = 0
            triggered = False
            continue
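The post-processing loop above is a small hysteresis state machine: a segment opens when a window's probability reaches threshold and only closes after min_silence_duration_ms spent below neg_threshold. As a sketch, the same logic can be factored into a pure function so it can be exercised on synthetic probabilities without audio, onnxruntime or tract. The name speech_segments is hypothetical, and its defaults simply mirror the values used in the script.

```python
def speech_segments(probs, window_size_samples=1536, sample_rate=16000,
                    threshold=0.5, neg_threshold=0.35,
                    min_silence_duration_ms=100, min_speech_duration_ms=250):
    """Return (start_s, end_s) speech segments from per-window probabilities.

    speech_segments is a hypothetical helper (not part of Silero or tract);
    it reproduces the hysteresis loop of the script above.
    """
    min_silence_samples = min_silence_duration_ms * sample_rate / 1000
    min_speech_samples = min_speech_duration_ms * sample_rate / 1000
    segments = []
    triggered = False
    current_speech = 0
    temp_end = 0  # 0 means "no pending end of speech"
    for i, p in enumerate(probs):
        pos = window_size_samples * i
        if p >= threshold and temp_end:
            temp_end = 0  # speech resumed: cancel the pending end
        if p >= threshold and not triggered:
            triggered = True
            current_speech = pos
            continue
        if p < neg_threshold and triggered:
            if not temp_end:
                temp_end = pos  # first low window after speech
            if pos - temp_end < min_silence_samples:
                continue  # silence not long enough yet
            if temp_end - current_speech > min_speech_samples:
                segments.append((current_speech / sample_rate,
                                 temp_end / sample_rate))
            temp_end = 0
            triggered = False
    return segments
```

With 1536-sample windows at 16 kHz (96 ms each), the 100 ms silence requirement (1600 samples) means a segment is closed on the third consecutive window below neg_threshold, and the 250 ms speech requirement (4000 samples) means a segment must span at least three windows (4608 samples) to be reported.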

And here is the same thing in Rust with tract (except that the wav is read from the io.npz file generated by the Python script):

use ndarray_npy::NpzReader;
use tract_ndarray::Array2;
use tract_onnx::{prelude::*, tract_hir::internal::DimLike};

fn main() -> TractResult<()> {
    let window_size_samples = 1536;
    let model = onnx()
        .model_for_path("../silero-vad-3.1/files/silero_vad.onnx")?
        .with_input_names(["input", "h0", "c0"])?
        .with_output_names(["output", "hn", "cn"])?
        .with_input_fact(
            0,
            InferenceFact::dt_shape(f32::datum_type(), tvec!(1, window_size_samples)),
        )?
        .with_input_fact(1, InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)))?
        .with_input_fact(2, InferenceFact::dt_shape(f32::datum_type(), tvec!(2, 1, 64)))?
        .into_optimized()?
        .into_runnable()?;
    let mut npz = NpzReader::new(std::fs::File::open("../silero-vad-3.1/io.npz")?)?;
    let wav: Array2<f32> = npz.by_name("wav.npy")?;
    let wav = wav.into_arc_tensor();
    let samples = wav.shape()[1];
    let mut h = Tensor::zero::<f32>(&[2, 1, 64])?;
    let mut c = Tensor::zero::<f32>(&[2, 1, 64])?;
    let mut output: Vec<f32> = vec![];

    for ix in 0..samples.divceil(window_size_samples) {
        let offset = ix * window_size_samples;
        let mut x = Tensor::zero::<f32>(&[1, window_size_samples])?;
        let chunk_len = (samples - offset).min(window_size_samples);
        x.assign_slice(0..chunk_len, &wav, offset..offset + chunk_len, 1)?;
        let mut outputs = model.run(tvec!(x, h, c))?;
        c = outputs.remove(2).into_tensor();
        h = outputs.remove(1).into_tensor();
        output.push(outputs[0].as_slice::<f32>()?[1]);
    }

    let min_silence_duration_ms = 100;
    let min_speech_duration_ms = 250;
    let threshold = 0.5;
    let neg_threshold = 0.35;
    let min_silence_samples = min_silence_duration_ms * 16000 / 1000;
    let min_speech_samples = min_speech_duration_ms * 16000 / 1000;

    let mut triggered = false;
    let mut current_speech = 0;
    let mut temp_end = 0;

    for (ix, speech_prob) in output.into_iter().enumerate() {
        if speech_prob >= threshold && temp_end != 0 {
            temp_end = 0;
        }
        if speech_prob >= threshold && !triggered {
            triggered = true;
            current_speech = window_size_samples * ix;
        } else if speech_prob < neg_threshold && triggered {
            if temp_end == 0 {
                temp_end = window_size_samples * ix;
            }
            if (window_size_samples * ix) - temp_end >= min_silence_samples {
                if temp_end - current_speech > min_speech_samples {
                    println!("{} {}", current_speech as f32 / 16000., temp_end as f32 / 16000.);
                }
                temp_end = 0;
                triggered = false
            }
        }
    }

    Ok(())
}
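Both versions feed the model fixed windows of window_size_samples (1536) samples and zero-pad the final, shorter window; in the Rust code this is the Tensor::zero buffer filled with assign_slice, with divceil giving the number of windows. A numpy-only sketch of the same framing can help check that the Python and Rust sides produce identical chunks. The helper name frame_windows is hypothetical.

```python
import numpy as np

def frame_windows(wav, window_size_samples=1536):
    """Split a (1, samples) waveform into zero-padded (1, window_size_samples) chunks.

    frame_windows is a hypothetical helper mirroring the Rust loop above.
    """
    samples = wav.shape[1]
    n_windows = -(-samples // window_size_samples)  # ceiling division, like divceil
    frames = []
    for ix in range(n_windows):
        offset = ix * window_size_samples
        chunk_len = min(samples - offset, window_size_samples)
        x = np.zeros((1, window_size_samples), dtype=np.float32)  # like Tensor::zero
        x[:, :chunk_len] = wav[:, offset:offset + chunk_len]     # like assign_slice
        frames.append(x)
    return frames
```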

7r3nzy commented 2 years ago

Hey, this is really great, thank you so much for going the extra mile. I am definitely going to try it today and will let you know how it goes :).

7r3nzy commented 2 years ago

Works perfectly! Thanks a lot. Closing the issue, since I see you have merged the changes as well. I should mention that the incorrect-values issue was my own mistake; your code helped there as well, though :).

7r3nzy commented 2 years ago

Also, I am interested in turning this into an example. Would you accept a PR for it? I think that should resolve https://github.com/sonos/tract/issues/114

kali commented 2 years ago

Please do! Indeed, it would be one answer to 114. There are other ways to do streaming (through the pulse system) but this is a perfectly valid approach.
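The approach in this thread streams by running the model on one window at a time and feeding the returned hn/cn back in as h0/c0, which is what lets a recurrent layer see context beyond the current window. A toy sketch illustrates the idea: a scalar tanh recurrence stands in for the model's LSTM (an assumption for illustration only, not Silero's architecture). For such a pure recurrence, chunked processing with carried state reproduces a single pass over the whole sequence, while resetting the state at every chunk does not.

```python
import numpy as np

def rnn_seq(x, h, a=0.9, b=0.5):
    """Toy recurrence h_t = tanh(a * h_{t-1} + b * x_t), standing in for an LSTM.

    Returns the per-step outputs and the final state.
    """
    out = np.empty_like(x)
    for t, xt in enumerate(x):
        h = np.tanh(a * h + b * xt)
        out[t] = h
    return out, h

def run_chunked(x, chunk, carry_state=True):
    """Process x in chunks, optionally carrying the state between chunks."""
    h = 0.0
    outs = []
    for start in range(0, len(x), chunk):
        y, h_next = rnn_seq(x[start:start + chunk], h)
        outs.append(y)
        # carrying h_next mirrors feeding hn/cn back as h0/c0
        h = h_next if carry_state else 0.0
    return np.concatenate(outs)
```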

chriskyndrid commented 1 year ago

+1, @kali, thank you for the Rust example! I use tract for various models, and it's really great! In this case I needed GPU support, so I implemented it with Ort, but your example made that translation a lot easier.