seanghay / uvr-mdx-infer

Ultimate Vocal Remover Inference CLI
41 stars 7 forks source link

Can I run these ONNX models in the browser with onnxruntime-web? #5

Closed Arvin8613 closed 3 months ago

Arvin8613 commented 3 months ago

These ONNX models are awesome! I want to run UVR-MDX-NET-Inst_HQ_3.onnx in the browser with onnxruntime-web, like what transformers.js is doing.

I've tried my best to run inference in the browser, but the output is never correct compared with the output produced by Python.

Could anyone help me out?

Here is my code running in the browser:


<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Audio Processing with ONNX</title>
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@latest/dist/ort.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/fft.js@4.0.4/lib/fft.min.js"></script>
</head>
<body>
  <h1>Audio Processing with ONNX</h1>
  <input type="file" id="audioInput" accept=".wav">
  <button onclick="processAudio()">Process Audio</button>
  <div id="output"></div>

  <script>
    /**
     * Decode an audio file, run an STFT over its first channel and pack the
     * result into the tensor that is fed to the ONNX session.
     *
     * @param {Blob} file - Audio file selected by the user.
     * @returns {Promise<ort.Tensor>} float32 tensor of shape [1, 4, 3072, 256];
     *   positions not covered by STFT data stay zero.
     */
    async function preprocessAudio(file) {
      const AudioContextClass = window.AudioContext || window.webkitAudioContext;
      // decodeAudioData resamples to the context's sample rate, so pin it to
      // 44100 Hz (the rate playAudio uses) instead of the device default.
      const audioContext = new AudioContextClass({ sampleRate: 44100 });

      let channelData;
      try {
        const arrayBuffer = await file.arrayBuffer();
        const audioBuffer = await audioContext.decodeAudioData(arrayBuffer);
        // Only the first channel is used.
        channelData = audioBuffer.getChannelData(0);
      } finally {
        // The context is only needed for decoding; release it so repeated
        // runs do not accumulate open audio contexts.
        await audioContext.close();
      }

      const n_fft = 4096;
      const hop_length = 1024;
      const hannWin = hannWindow(n_fft);
      const stftData = stft(channelData, n_fft, hop_length, hannWin);

      const requiredSize = 4 * 3072 * 256;  // 3145728

      const inputTensorData = new Float32Array(requiredSize);

      // NOTE(review): frames are copied flat, frame after frame, as interleaved
      // re/im values until the tensor is full. Nothing here arranges the data
      // along the [4, 3072, 256] axes - confirm the layout the model expects
      // against the reference Python preprocessing.
      let index = 0;
      for (const frame of stftData) {
        if (index >= requiredSize) {
          break; // tensor is full; remaining frames are dropped
        }
        const count = Math.min(frame.length, requiredSize - index);
        for (let j = 0; j < count; j++) {
          inputTensorData[index++] = frame[j];
        }
      }

      return new ort.Tensor('float32', inputTensorData, [1, 4, 3072, 256]);
    }

    /**
     * Build a Hann window.
     *
     * @param {number} length - Number of samples in the window.
     * @param {boolean} [periodic=false] - When false (default) the window is
     *   symmetric (denominator `length - 1`); when true the denominator is
     *   `length`, which gives the periodic variant.
     * @returns {Float32Array} Window coefficients.
     */
    function hannWindow(length, periodic = false) {
      const win = new Float32Array(length);
      if (length === 1) {
        // A one-sample symmetric window would compute 0 / 0 and yield NaN.
        win[0] = 1;
        return win;
      }
      const denominator = periodic ? length : length - 1;
      for (let i = 0; i < length; i++) {
        win[i] = 0.5 * (1 - Math.cos((2 * Math.PI * i) / denominator));
      }
      return win;
    }

    /**
     * Short-time Fourier transform without padding: the signal is cut into
     * frames of `n_fft` samples taken every `hop_length` samples, each frame is
     * multiplied by `win` and transformed with fft.js.
     *
     * @param {Float32Array|number[]} signal - Input samples.
     * @param {number} n_fft - Frame / FFT size.
     * @param {number} hop_length - Step between frame starts; must be positive.
     * @param {Float32Array|number[]} win - Window with `n_fft` coefficients.
     * @returns {Array<number[]>} One complex array per frame, as created by
     *   `fft.createComplexArray()` and filled by `realTransform` plus
     *   `completeSpectrum`. Empty when the signal is shorter than `n_fft`.
     * @throws {RangeError} If `hop_length` is not a positive number.
     */
    function stft(signal, n_fft, hop_length, win) {
      if (!(hop_length > 0)) {
        // A zero or negative hop would never advance the loop below.
        throw new RangeError('hop_length must be a positive number');
      }
      const stftData = [];
      const fft = new FFT(n_fft);
      // Inclusive bound: a frame starting at signal.length - n_fft still fits
      // entirely inside the signal and must not be dropped.
      const lastStart = signal.length - n_fft;
      for (let start = 0; start <= lastStart; start += hop_length) {
        const segment = signal.slice(start, start + n_fft);
        const windowed = segment.map((sample, index) => sample * win[index]);
        const out = fft.createComplexArray();
        fft.realTransform(out, windowed);
        fft.completeSpectrum(out);
        stftData.push(out);
      }
      return stftData;
    }

    /**
     * Click handler: runs the selected file through preprocessing, the ONNX
     * model and the inverse STFT, then plays the result.
     *
     * Failures are logged and shown in the `#output` element; without this the
     * rejected promise of an `onclick` handler goes unnoticed by the user.
     *
     * @returns {Promise<void>}
     */
    async function processAudio() {
      const fileInput = document.getElementById('audioInput');
      const file = fileInput.files[0];
      if (!file) {
        alert('Please select a WAV file first.');
        return;
      }

      try {
        const inputTensor = await preprocessAudio(file);
        const session = await ort.InferenceSession.create('UVR-MDX-NET-Inst_HQ_3.onnx');
        // The feed key and result key are the tensor names used by the model.
        const inputs = { input: inputTensor };
        const results = await session.run(inputs);
        const outputData = results.output.data;

        const istftData = istft(outputData);

        playAudio(istftData);
      } catch (err) {
        console.error('Audio processing failed', err);
        const outputElement = document.getElementById('output');
        if (outputElement) {
          // textContent, not innerHTML: the message may contain arbitrary text.
          outputElement.textContent = `Processing failed: ${err && err.message ? err.message : err}`;
        }
      }
    }

    /**
     * Inverse short-time Fourier transform by windowed overlap-add.
     *
     * `data` is read as consecutive frames of `2 * n_fft` values, each frame
     * being an interleaved complex spectrum (re, im, re, im, ...). That is the
     * frame format `stft` in this file produces.
     * NOTE(review): confirm the model output really uses this layout.
     *
     * Each frame is inverse-transformed, its real part is multiplied by the
     * Hann window and added at `frameIndex * hop_length`. The sum is then
     * divided by the accumulated squared window so that analysis + synthesis
     * windowing does not scale the signal.
     *
     * @param {Float32Array|number[]} data - Concatenated complex frames.
     * @param {number} [n_fft=4096] - FFT / frame size.
     * @param {number} [hop_length=1024] - Step between frame starts.
     * @returns {Float32Array} Time-domain samples, length
     *   `(frames - 1) * hop_length + n_fft`, or empty when `data` holds no
     *   complete frame.
     */
    function istft(data, n_fft = 4096, hop_length = 1024) {
      const frameSize = 2 * n_fft;
      const numSegments = Math.floor(data.length / frameSize);
      if (numSegments === 0) {
        return new Float32Array(0);
      }

      const hannWin = hannWindow(n_fft);
      const ifft = new FFT(n_fft);

      // The last frame starts at (numSegments - 1) * hop_length and is n_fft long.
      const totalLength = (numSegments - 1) * hop_length + n_fft;
      const istftData = new Float32Array(totalLength);
      const windowEnergy = new Float32Array(totalLength);
      const out = ifft.createComplexArray();

      for (let i = 0; i < numSegments; i++) {
        const frame = data.slice(i * frameSize, (i + 1) * frameSize);
        ifft.inverseTransform(out, frame);

        const offset = i * hop_length;
        for (let j = 0; j < n_fft; j++) {
          const w = hannWin[j];
          // out is interleaved complex; the real part sits at even indices.
          istftData[offset + j] += out[2 * j] * w;
          windowEnergy[offset + j] += w * w;
        }
      }

      // Samples covered only by (near-)zero window values cannot be recovered
      // and are left at zero rather than divided by ~0.
      for (let k = 0; k < totalLength; k++) {
        if (windowEnergy[k] > 1e-8) {
          istftData[k] /= windowEnergy[k];
        }
      }
      return istftData;
    }

    /**
     * Play mono samples through the default audio output at 44100 Hz.
     *
     * @param {Float32Array} data - Samples to play; nothing happens when empty.
     */
    function playAudio(data) {
      if (data.length === 0) {
        // createBuffer rejects a zero-length buffer.
        return;
      }
      const audioContext = new (window.AudioContext || window.webkitAudioContext)();
      const buffer = audioContext.createBuffer(1, data.length, 44100);
      buffer.copyToChannel(data, 0);
      const source = audioContext.createBufferSource();
      source.buffer = buffer;
      source.connect(audioContext.destination);
      // A fresh context is created per call; close it once playback is done
      // so repeated runs do not pile up open audio contexts.
      source.onended = () => {
        audioContext.close().catch((err) => console.error('Failed to close AudioContext', err));
      };
      source.start();
    }
  </script>
</body>
</html>```
seanghay commented 3 months ago

I had the same idea, but I never figured out how to implement the STFT/inverse STFT functions in JavaScript.