elixir-nx / ortex

ONNX Runtime bindings for Elixir
MIT License
122 stars 15 forks source link

On main, using silero VAD, output shapes are inconsistent #37

Closed johns10 closed 2 months ago

johns10 commented 2 months ago

I'm using this code for VAD:

defmodule TodoApp.Audio.VADSplitter do
  @moduledoc """
  Membrane filter that runs the Silero VAD ONNX model over fixed 30 ms chunks of
  mono 16 kHz `f32le` audio and emits a `Membrane.File.SplitEvent` around each
  run of chunks whose speech probability is at or above `@threshold`.

  Chunks below the threshold are dropped, except for the first one after a
  speech run, which is kept and followed by a split. Bytes that do not fill a
  whole chunk are carried over and prepended to the next input buffer.
  """
  use Membrane.Filter

  alias Membrane.{RawAudio, Buffer}
  alias Membrane.File.SplitEvent

  @chunk_duration Membrane.Time.milliseconds(30)
  @sample_rate 16_000
  @threshold 0.50

  def_input_pad(:input,
    accepted_format: %RawAudio{sample_format: :f32le, channels: 1, sample_rate: 16_000}
  )

  def_output_pad(:output,
    accepted_format: %RawAudio{sample_format: :f32le, channels: 1, sample_rate: 16_000}
  )

  @impl true
  def handle_init(_ctx, _opts) do
    model = Ortex.load(Path.join([:code.priv_dir(:todo_app), "models", "silero_vad.onnx"]))

    {[],
     %{
       model: model,
       # State tensors passed into the model and replaced by its outputs on every chunk.
       h: Nx.broadcast(0.0, {2, 1, 64}),
       c: Nx.broadcast(0.0, {2, 1, 64}),
       # Bytes per chunk; only known once the stream format has arrived.
       chunk_size: nil,
       # Speech probability computed for the previous full chunk.
       last: 0.0,
       # Trailing bytes shorter than one chunk, prepended to the next buffer.
       queue: <<>>
     }}
  end

  @impl true
  def handle_stream_format(:input, stream_format, _ctx, state) do
    chunk_size = RawAudio.time_to_bytes(@chunk_duration, stream_format)
    {[stream_format: {:output, stream_format}], %{state | chunk_size: chunk_size}}
  end

  @impl true
  def handle_buffer(:input, %Buffer{payload: payload} = buffer, _ctx, state) do
    dts = Buffer.get_dts_or_pts(buffer)

    # `segments` is built by prepending, so it is reversed before use.
    {segments, state} =
      (state.queue <> payload)
      |> generate_chunks(state.chunk_size)
      |> Enum.reduce({[], state}, &classify_chunk/2)

    actions =
      segments
      |> Enum.reverse()
      |> Enum.concat()
      |> to_membrane_actions(dts)

    {actions, state}
  end

  # Scores one full-size chunk and prepends the chunks/`:split` markers it produces.
  # A shorter trailing chunk is not scored; it becomes the new queue instead.
  defp classify_chunk(chunk, {segments, %{chunk_size: chunk_size} = state})
       when byte_size(chunk) == chunk_size do
    # Read h/c from the state threaded through the reduce, so each chunk sees the
    # state produced by the previous chunk rather than the buffer's initial state.
    {prob, new_h, new_c} = do_predict(state.model, state.h, state.c, chunk)
    segment = transition(state.last >= @threshold, prob >= @threshold, chunk)

    {[segment | segments], %{state | h: new_h, c: new_c, last: prob, queue: <<>>}}
  end

  defp classify_chunk(chunk, {segments, state}), do: {segments, %{state | queue: chunk}}

  # Maps {previous chunk was speech?, current chunk is speech?} to output items.
  defp transition(false, true, chunk), do: [:split, chunk]
  defp transition(false, false, _chunk), do: []
  defp transition(true, false, chunk), do: [chunk, :split]
  defp transition(true, true, chunk), do: [chunk]

  # Converts an ordered list of binaries and `:split` markers into Membrane actions,
  # preserving order: consecutive binaries are merged into one buffer, and any
  # pending buffer is emitted before the split event that follows it.
  defp to_membrane_actions(items, dts) do
    {reversed, pending} =
      Enum.reduce(items, {[], nil}, fn
        :split, {acc, pending} ->
          {[{:event, {:output, %SplitEvent{}}} | flush(acc, pending, dts)], nil}

        bin, {acc, nil} when is_binary(bin) ->
          {acc, [bin]}

        bin, {acc, pending} when is_binary(bin) ->
          {acc, [pending, bin]}
      end)

    reversed
    |> flush(pending, dts)
    |> Enum.reverse()
  end

  # Prepends a buffer built from the pending iodata (if any) to a reversed action list.
  defp flush(actions, nil, _dts), do: actions

  defp flush(actions, pending, dts) do
    buffer = %Buffer{payload: IO.iodata_to_binary(pending), dts: dts}
    [{:buffer, {:output, buffer}} | actions]
  end

  # Runs the model on one chunk and returns {speech_probability, new_h, new_c}.
  defp do_predict(model, h, c, audio) do
    # NOTE(review): Nx.from_binary/2 reads the bytes in the host's native
    # endianness; the pad format is f32le, so this presumably assumes a
    # little-endian host — confirm if deploying elsewhere.
    input = audio |> Nx.from_binary(:f32) |> Nx.new_axis(0)
    sr = Nx.tensor(@sample_rate)

    {output, new_h, new_c} =
      model
      |> Ortex.run({input, sr, h, c})
      |> order_outputs()

    prob = output |> Nx.squeeze() |> Nx.to_number()
    {prob, new_h, new_c}
  end

  # Recovers {output, hn, cn} from the tuple returned by Ortex.run/2.
  #
  # On ortex main the outputs may come back in a varying order
  # (elixir-nx/ortex#26). The probability is the only {1, 1} tensor, so it is
  # located by shape. hn and cn both have shape {2, 1, 64} and cannot be told
  # apart, so this ASSUMES the outputs always arrive as a rotation of
  # [output, hn, cn]. That assumption is unverified (TODO confirm). Pinning
  # ortex 0.1.9, where the order is stable, avoids the issue entirely.
  defp order_outputs(outputs) do
    tensors = Tuple.to_list(outputs)

    case Enum.find_index(tensors, &(Nx.shape(&1) == {1, 1})) do
      nil ->
        raise ArgumentError,
              "expected one {1, 1} probability tensor among the model outputs, got shapes: " <>
                inspect(Enum.map(tensors, &Nx.shape/1))

      index ->
        {leading, [output | trailing]} = Enum.split(tensors, index)
        [new_h, new_c] = trailing ++ leading
        {output, new_h, new_c}
    end
  end

  # Splits `samples` into `chunk_size`-byte chunks. The last element is always
  # the (possibly empty) remainder that is shorter than one chunk.
  defp generate_chunks(samples, chunk_size) when byte_size(samples) >= chunk_size do
    <<chunk::binary-size(chunk_size), rest::binary>> = samples
    [chunk | generate_chunks(rest, chunk_size)]
  end

  defp generate_chunks(samples, _chunk_size) do
    [samples]
  end
end

On 0.1.9, I would consistently get tensors with shape: output: {1, 1} cn: {2, 1, 64} hn: {2, 1, 64}

After upgrading to main, the order of the return values is inconsistent. Logging the shapes over multiple iterations yields:

old_h_shape: {2, 1, 64}
old_c_shape: {2, 1, 64}
Output Shape: {1, 1}
new_h_shape: {2, 1, 64}
new_c_shape: {2, 1, 64}
old_h_shape: {2, 1, 64}
old_c_shape: {2, 1, 64}
Output Shape: {2, 1, 64}
new_h_shape: {2, 1, 64}
new_c_shape: {1, 1}
johns10 commented 2 months ago

By all accounts, this code "works":

    # Workaround for ortex main returning model outputs in a varying order
    # (tracked in elixir-nx/ortex#26).
    input = Nx.from_binary(audio, :f32) |> Nx.new_axis(0)
    sr = Nx.tensor(@sample_rate)
    # {output, new_h, new_c} = Ortex.run(model, {input, sr, h, c})

    # The `c` inside the input tuple is still the incoming state; it is rebound
    # to the third output only after Ortex.run/2 returns.
    {a, b, c} = Ortex.run(model, {input, sr, h, c})

    # The probability output is the only {1, 1} tensor, so it is found by shape.
    # hn and cn share the shape {2, 1, 64} and cannot be told apart. The clauses
    # below assume the outputs always arrive as a rotation of [output, hn, cn],
    # which is an unverified assumption (TODO confirm). Any other ordering raises
    # CaseClauseError.
    {output, new_h, new_c} =
      case [Nx.shape(a), Nx.shape(b), Nx.shape(c)] do
        [{1, 1}, {2, 1, 64}, {2, 1, 64}] -> {a, b, c}
        [{2, 1, 64}, {2, 1, 64}, {1, 1}] -> {c, a, b}
        [{2, 1, 64}, {1, 1}, {2, 1, 64}] -> {b, c, a}
      end

    prob = output |> Nx.squeeze() |> Nx.to_number()

    {prob, new_h, new_c}
  end

I'm locating the output based on the shape, and then assuming hn and cn come in the same order. The code runs, and the integration test I have locally "works."

mortont commented 2 months ago

The ordering is a known issue (https://github.com/elixir-nx/ortex/issues/26), we're waiting on upstream ort to have a 2.0 non-rc release before updating main to match. Unfortunately until then, I'd recommend using 0.1.9.

Is there a specific feature you're looking for in main that's not in the 0.1.9 release?

johns10 commented 2 months ago

I've been trying to load up the model exported from

https://github.com/pengzhendong/pyannote-onnx/blob/master/pyannote_onnx/pyannote_onnx.py

I keep getting opset errors, no matter what version I try to export from there.

mortont commented 2 months ago

Makes sense. I'm going to close this issue since the ordering of output is already tracked in #26.

In the meantime, I've noticed that pytorch sometimes couples its version to the opset versions it can export, so you may have luck downgrading it. Also, this may be helpful for tracking the opset/IR compatibility matrix. Ortex 0.1.9 is on onnxruntime 1.14.1.