emscripten-core / emscripten

Emscripten: An LLVM-to-WebAssembly Compiler

Significant Inference Time Increase When Compiling with Emscripten for WASM #20570

Open yusuf-ilgun opened 1 year ago

yusuf-ilgun commented 1 year ago

Description:

I have been working on a C++ inference project using onnxruntime. The goal is to compile the C++ code into WebAssembly (WASM) for use in jitsi-meet.

During my testing, I noticed a significant performance degradation in inference time depending on how I compiled with Emscripten:

Native (no WASM): inference runs in about 1 ms on average.
WASM (Native): average inference time increases to about 5 ms.
WASM (Web): average inference time increases further to about 11 ms.

I've tried multiple versions of Emscripten. Version 3.1.41 is the latest that works for me due to an atob issue in subsequent versions. I believe solving the native compilation problem might also address the web issue.

Environment Information:

Emscripten Version: 3.1.41 (latest working version)
Operating System: Ubuntu 22.04
Browser/Version: latest Chrome/Firefox
Other related software versions:
    ONNX Runtime: 1.17.0

Full command line (tried with and without optimizations, -msimd128, and -flto):

emcc \
    -O3 \
    -g0 \
    -msimd128 -flto -frtti \
    -sALLOW_MEMORY_GROWTH=1 \
    -sMALLOC=emmalloc \
    -sENVIRONMENT="node" \
    --embed-file dns48_stream_depth=4_stride=128_quantized.onnx \
    --preload-file audio.wav \
    -sASSERTIONS=0 \
    -sSTACK_SIZE=128KB \
    -sINITIAL_MEMORY=22MB \
    -sSINGLE_FILE=1 \
    -L. -lonnxruntime_webassembly -lfvad \
    denoise_streaming.cpp wave.c decode.c \
    -o rnnoise-node.js

The code might not be relevant, but here it is:

#include <iostream>
#include <vector>
#include <string>  // std::string (was missing)
#include <chrono>  // std::chrono::steady_clock timing below (was missing)
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cstdlib>

#include "include/onnxruntime_cxx_api.h"
#include "include/fvad.h"
#include "include/wave.h"
#include "include/decode.h"

std::vector<float> read_audio_file(const std::string& filepath) {
    FILE *fp = fopen(filepath.c_str(), "rb"); // binary mode: WAV data is binary
    if (fp == NULL) {
        std::cerr << "Error opening file: " << filepath << std::endl;
        exit(1);
    }

    fseek(fp, 0, SEEK_END);
    long size = ftell(fp);
    fseek(fp, 0, SEEK_SET);

    printf("file size = %ld\n", size);

    uint8_t* data = (uint8_t*)malloc(size);
    if (!data) {
        std::cerr << "Failed to allocate memory" << std::endl;
        fclose(fp);
        exit(1);
    }

    if (fread(data, 1, size, fp) != (size_t)size) {
        std::cerr << "Failed to read file: " << filepath << std::endl;
        free(data);
        fclose(fp);
        exit(1);
    }
    fclose(fp);

    WAVFile file = WAV_ParseFileData(data);

    int sample_count = file.data_length / (file.header.bits_per_sample / 8);

    std::cout << "file.data_length: " << file.data_length << std::endl;
    std::cout << "file.header.bits_per_sample: " << file.header.bits_per_sample << std::endl;
    std::cout << "Sample count: " << sample_count << std::endl;

    // Allocating memory for `noisy` and the float* channel buffer
    std::vector<float> noisy(sample_count);
    std::vector<float*> channels(file.header.number_of_channels);

    // For multi-channel audio, a buffer would need to be prepared for each channel.
    // Since we are processing mono, one channel is enough.
    channels[0] = noisy.data();

    // Pass the `channels.data()` as the float** parameter
    pcm16_decode(file.data, 0, channels.data(), file.header.number_of_channels, sample_count);

    free(data);

    // Returning the decoded PCM data as a vector<float>
    return noisy;
}

extern "C" {
    // Define the input and output names
    const char* input_names[] = {
                "input",
                "frame_buffer",
                "frame_num.1",
                "variance.1",
                "resample_input_frame.1",
                "resample_out_frame.1",
                "h0.1",
                "c0.1",
                "conv_state.1"
    };

    // Define the names of the output tensors
    const char* output_names[] = {
                "output",
                "out_frame_buffer",
                "frame_num",
                "variance",
                "resample_input_frame",
                "resample_out_frame",
                "h0",
                "c0",
                "conv_state"
    };

    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

    int power(int base, int exponent) {
        int result = 1;

        for (int i = 0; i < exponent; ++i) {
            result *= base;
        }

        return result;
    }

    // Define the ConvStateSize tuple type
    struct ConvStateSize {
        int dim1;
        int dim2;
        int dim3;
    };

    struct InputShapes {
        int64_t input_shape[2];
        size_t input_shape_len;
        size_t input_len;

        int64_t frame_buffer_shape[2];
        size_t frame_buffer_shape_len;
        size_t frame_buffer_len;

        int64_t frame_num_1_shape[1];
        size_t frame_num_1_shape_len;
        size_t frame_num_1_len;

        int64_t variance_1_shape[1];
        size_t variance_1_shape_len;
        size_t variance_1_len;

        int64_t resample_input_frame_1_shape[2];
        size_t resample_input_frame_1_shape_len;
        size_t resample_input_frame_1_len;

        int64_t resample_out_frame_1_shape[2];
        size_t resample_out_frame_1_shape_len;
        size_t resample_out_frame_1_len;

        int64_t h0_1_shape[3];
        size_t h0_1_shape_len;
        size_t h0_1_len;

        int64_t c0_1_shape[3];
        size_t c0_1_shape_len;
        size_t c0_1_len;

        int64_t conv_state_1_shape[2];
        size_t conv_state_1_shape_len;
        size_t conv_state_1_len;
    };

    struct ModelInputs {
        std::vector<float> input;
        std::vector<float> frame_buffer;
        std::vector<int64_t> frame_num_1;
        std::vector<float> variance_1;
        std::vector<float> resample_input_frame_1;
        std::vector<float> resample_out_frame_1;
        std::vector<float> h0_1;
        std::vector<float> c0_1;
        std::vector<float> conv_state_1;
        InputShapes inputShapes;
    };

    struct RNNModel {
        int place_holder;
    };

    struct DenoiseState {
        Ort::Env* env;
        Ort::SessionOptions* sessionOptions;
        Ort::Session* session;
        Ort::RunOptions *run_options;
        std::string modelPath;

        ModelInputs modelInputs;
        Fvad *vad;

        int depth;
        int hidden;
        int frame_length;
        int frame_number;
        int stride;
    };

    int rnnoise_get_size() {
        return sizeof(DenoiseState);
    }

    int rnnoise_init(DenoiseState* state, RNNModel *model) {
        state->modelPath = "dns48_stream_depth=4_stride=128_quantized.onnx";
        state->env = new Ort::Env(ORT_LOGGING_LEVEL_FATAL, "ONNXRuntimeExample");
        state->sessionOptions = new Ort::SessionOptions();
        //state->sessionOptions->SetGraphOptimizationLevel(ORT_ENABLE_ALL);
        //state->sessionOptions->EnableProfiling("my_profile_file.profile");
        //state->sessionOptions->SetIntraOpNumThreads(4); // Set the number of intra-op threads.
        //state->sessionOptions->SetInterOpNumThreads(2); // Set the number of inter-op threads.
        state->session = new Ort::Session(*(state->env), state->modelPath.c_str(), *(state->sessionOptions));

        state->depth = 4;
        state->hidden = 48;
        state->frame_length = 128;
        state->frame_number = 0;
        state->stride = 128;

        int depth = state->depth;
        int hidden = state->hidden;
        int frame_length = state->frame_length;

        Ort::AllocatorWithDefaultOptions allocator;

        state->modelInputs.inputShapes.input_shape[0] = 1;
        state->modelInputs.inputShapes.input_shape[1] = frame_length;
        state->modelInputs.inputShapes.input_shape_len = sizeof(state->modelInputs.inputShapes.input_shape) / sizeof(state->modelInputs.inputShapes.input_shape[0]);
        state->modelInputs.inputShapes.input_len = state->modelInputs.inputShapes.input_shape[0] * state->modelInputs.inputShapes.input_shape[1];

        state->modelInputs.inputShapes.frame_buffer_shape[0] = 1;
        state->modelInputs.inputShapes.frame_buffer_shape[1] = 362;
        state->modelInputs.inputShapes.frame_buffer_shape_len = sizeof(state->modelInputs.inputShapes.frame_buffer_shape) / sizeof(state->modelInputs.inputShapes.frame_buffer_shape[0]);
        state->modelInputs.inputShapes.frame_buffer_len = state->modelInputs.inputShapes.frame_buffer_shape[0] * state->modelInputs.inputShapes.frame_buffer_shape[1];

        state->modelInputs.inputShapes.frame_num_1_shape[0] = 1;
        state->modelInputs.inputShapes.frame_num_1_shape_len = sizeof(state->modelInputs.inputShapes.frame_num_1_shape) / sizeof(state->modelInputs.inputShapes.frame_num_1_shape[0]);
        state->modelInputs.inputShapes.frame_num_1_len = state->modelInputs.inputShapes.frame_num_1_shape[0];

        state->modelInputs.inputShapes.variance_1_shape[0] = 1;
        state->modelInputs.inputShapes.variance_1_shape_len = sizeof(state->modelInputs.inputShapes.variance_1_shape) / sizeof(state->modelInputs.inputShapes.variance_1_shape[0]);
        state->modelInputs.inputShapes.variance_1_len = state->modelInputs.inputShapes.variance_1_shape[0];

        state->modelInputs.inputShapes.resample_input_frame_1_shape[0] = 1;
        state->modelInputs.inputShapes.resample_input_frame_1_shape[1] = frame_length;
        state->modelInputs.inputShapes.resample_input_frame_1_shape_len = sizeof(state->modelInputs.inputShapes.resample_input_frame_1_shape) / sizeof(state->modelInputs.inputShapes.resample_input_frame_1_shape[0]);
        state->modelInputs.inputShapes.resample_input_frame_1_len = state->modelInputs.inputShapes.resample_input_frame_1_shape[0] * state->modelInputs.inputShapes.resample_input_frame_1_shape[1];  

        state->modelInputs.inputShapes.resample_out_frame_1_shape[0] = 1;
        state->modelInputs.inputShapes.resample_out_frame_1_shape[1] = frame_length;
        state->modelInputs.inputShapes.resample_out_frame_1_shape_len = sizeof(state->modelInputs.inputShapes.resample_out_frame_1_shape) / sizeof(state->modelInputs.inputShapes.resample_out_frame_1_shape[0]);
        state->modelInputs.inputShapes.resample_out_frame_1_len = state->modelInputs.inputShapes.resample_out_frame_1_shape[0] * state->modelInputs.inputShapes.resample_out_frame_1_shape[1];  

        state->modelInputs.inputShapes.h0_1_shape[0] = 2;
        state->modelInputs.inputShapes.h0_1_shape[1] = 1;
        state->modelInputs.inputShapes.h0_1_shape[2] = 384;
        state->modelInputs.inputShapes.h0_1_shape_len = sizeof(state->modelInputs.inputShapes.h0_1_shape) / sizeof(state->modelInputs.inputShapes.h0_1_shape[0]);
        state->modelInputs.inputShapes.h0_1_len = state->modelInputs.inputShapes.h0_1_shape[0] * state->modelInputs.inputShapes.h0_1_shape[1] * state->modelInputs.inputShapes.h0_1_shape[2];

        state->modelInputs.inputShapes.c0_1_shape[0] = 2;
        state->modelInputs.inputShapes.c0_1_shape[1] = 1;
        state->modelInputs.inputShapes.c0_1_shape[2] = 384;
        state->modelInputs.inputShapes.c0_1_shape_len = sizeof(state->modelInputs.inputShapes.c0_1_shape) / sizeof(state->modelInputs.inputShapes.c0_1_shape[0]);
        state->modelInputs.inputShapes.c0_1_len = state->modelInputs.inputShapes.c0_1_shape[0] * state->modelInputs.inputShapes.c0_1_shape[1] * state->modelInputs.inputShapes.c0_1_shape[2];

        state->modelInputs.inputShapes.conv_state_1_shape[0] = 1;
        state->modelInputs.inputShapes.conv_state_1_shape[1] = 13444;
        state->modelInputs.inputShapes.conv_state_1_shape_len = sizeof(state->modelInputs.inputShapes.conv_state_1_shape) / sizeof(state->modelInputs.inputShapes.conv_state_1_shape[0]);
        state->modelInputs.inputShapes.conv_state_1_len = state->modelInputs.inputShapes.conv_state_1_shape[0] * state->modelInputs.inputShapes.conv_state_1_shape[1];

         // Initialize vectors
        state->modelInputs.input.resize(state->modelInputs.inputShapes.input_len);
        state->modelInputs.frame_buffer.resize(state->modelInputs.inputShapes.frame_buffer_len);
        state->modelInputs.frame_num_1.resize(state->modelInputs.inputShapes.frame_num_1_len);
        state->modelInputs.variance_1.resize(state->modelInputs.inputShapes.variance_1_len);
        state->modelInputs.resample_input_frame_1.resize(state->modelInputs.inputShapes.resample_input_frame_1_len);
        state->modelInputs.resample_out_frame_1.resize(state->modelInputs.inputShapes.resample_out_frame_1_len);
        state->modelInputs.h0_1.resize(state->modelInputs.inputShapes.h0_1_len);
        state->modelInputs.c0_1.resize(state->modelInputs.inputShapes.c0_1_len);
        state->modelInputs.conv_state_1.resize(state->modelInputs.inputShapes.conv_state_1_len);

        state->vad = fvad_new();
        if (!state->vad) {
            printf("fvad error: Not initiated\n");
        }

        // Set vad mode and sample rate
        int ret;
        ret = fvad_set_mode(state->vad, 2);  // Mode: 0, 1, 2 or 3
        if (ret < 0) {
            printf("fvad error: Invalid mode\n");
        }
        ret = fvad_set_sample_rate(state->vad, 16000);  // sample rate
        if (ret < 0) {
            printf("fvad error: Invalid sample rate\n");
        }

        return 0;
    }

    DenoiseState *rnnoise_create(RNNModel *model) {
        DenoiseState *st = new DenoiseState();
        rnnoise_init(st, model);
        return st;
    }

    void rnnoise_destroy(DenoiseState* state) {
        delete state->session;
        delete state->sessionOptions;
        delete state->env;
        fvad_free(state->vad); // release the VAD instance created in rnnoise_init
    }

    float rnnoise_process_frame(DenoiseState* state, float* out, const float* in) { 
        std::vector<float> frames(in, in + state->frame_length);

        std::vector<int16_t> int16_samples(state->frame_length);
        for (int i = 0; i < state->frame_length; i++) {
            int16_samples[i] = static_cast<int16_t>(frames[i] * 32768.0f);
        }

        //int vad_result = fvad_process(state->vad, int16_samples.data(), state->frame_length);
/*  !!!!!!!!!!!!!!!!!!!!!!!! DONT FORGET TO REMOVE !!!!!!!!!!!!!!!!!!!!!!!!!!!! */
        int vad_result = 1;
/*  !!!!!!!!!!!!!!!!!!!!!!!! DONT FORGET TO REMOVE !!!!!!!!!!!!!!!!!!!!!!!!!!!! */
        if (vad_result < 0) {
            // Handle error: invalid frame length.
            printf("fvad error: invalid frame length\n");
        } else if(vad_result == 0) {
            printf("fvad detected silence\n");
            for (int i = 0; i < state->frame_length; i++) {
                out[i] = 0;
            }
        } else if (vad_result > 0) {
            std::vector<Ort::Value> input_values;

            std::copy(frames.begin(), frames.end(), state->modelInputs.input.begin());
            // Create Tensor for input_frame
            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.input.data(), 
                state->modelInputs.inputShapes.input_len, 
                state->modelInputs.inputShapes.input_shape,
                state->modelInputs.inputShapes.input_shape_len
            ));

            // Create Tensor for frame_buffer
            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.frame_buffer.data(),
                state->modelInputs.inputShapes.frame_buffer_len,
                state->modelInputs.inputShapes.frame_buffer_shape,
                state->modelInputs.inputShapes.frame_buffer_shape_len
            ));

            // Create Tensor for frame_num_1
            input_values.push_back(Ort::Value::CreateTensor<int64_t>(
                memoryInfo,
                state->modelInputs.frame_num_1.data(),
                state->modelInputs.inputShapes.frame_num_1_len,
                state->modelInputs.inputShapes.frame_num_1_shape,
                state->modelInputs.inputShapes.frame_num_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.variance_1.data(),
                state->modelInputs.inputShapes.variance_1_len,
                state->modelInputs.inputShapes.variance_1_shape,
                state->modelInputs.inputShapes.variance_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.resample_input_frame_1.data(),
                state->modelInputs.inputShapes.resample_input_frame_1_len,
                state->modelInputs.inputShapes.resample_input_frame_1_shape,
                state->modelInputs.inputShapes.resample_input_frame_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.resample_out_frame_1.data(),
                state->modelInputs.inputShapes.resample_out_frame_1_len,
                state->modelInputs.inputShapes.resample_out_frame_1_shape,
                state->modelInputs.inputShapes.resample_out_frame_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.h0_1.data(),
                state->modelInputs.inputShapes.h0_1_len,
                state->modelInputs.inputShapes.h0_1_shape,
                state->modelInputs.inputShapes.h0_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.c0_1.data(),
                state->modelInputs.inputShapes.c0_1_len,
                state->modelInputs.inputShapes.c0_1_shape,
                state->modelInputs.inputShapes.c0_1_shape_len
            ));

            input_values.push_back(Ort::Value::CreateTensor<float>(
                memoryInfo,
                state->modelInputs.conv_state_1.data(),
                state->modelInputs.inputShapes.conv_state_1_len,
                state->modelInputs.inputShapes.conv_state_1_shape,
                state->modelInputs.inputShapes.conv_state_1_shape_len
            ));

            // TIME CALCULATION START
            std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
            std::vector<Ort::Value> outputTensors = state->session->Run(
                Ort::RunOptions{nullptr},
                input_names, 
                input_values.data(), 
                input_values.size(),
                output_names,
                9
            ); 
            std::chrono::steady_clock::time_point end2 = std::chrono::steady_clock::now();
            std::chrono::duration<double> elapsed_seconds2 = end2 - start;
            double elapsed_time2 = elapsed_seconds2.count();
            printf("Elapsed time 2: %.8f seconds\n",elapsed_time2);
            // TIME CALCULATION END

            memcpy(out, outputTensors[0].GetTensorMutableData<float>(), state->frame_length * sizeof(float));

            // Retrieve the outputs from outputTensors
            float* frame_buffer = outputTensors[1].GetTensorMutableData<float>();
            float* variance_1 = outputTensors[3].GetTensorMutableData<float>();
            float* resample_input_frame_1 = outputTensors[4].GetTensorMutableData<float>();
            float* resample_out_frame_1 = outputTensors[5].GetTensorMutableData<float>();
            float* h0_1 = outputTensors[6].GetTensorMutableData<float>();
            float* c0_1 = outputTensors[7].GetTensorMutableData<float>();
            float* conv_state_1 = outputTensors[8].GetTensorMutableData<float>();
            // and update the states in modelInputs
            std::copy(frame_buffer, frame_buffer + state->modelInputs.inputShapes.frame_buffer_len, state->modelInputs.frame_buffer.begin());
            std::copy(variance_1, variance_1 + state->modelInputs.inputShapes.variance_1_len, state->modelInputs.variance_1.begin());
            std::copy(resample_input_frame_1, resample_input_frame_1 + state->modelInputs.inputShapes.resample_input_frame_1_len, state->modelInputs.resample_input_frame_1.begin());
            std::copy(resample_out_frame_1, resample_out_frame_1 + state->modelInputs.inputShapes.resample_out_frame_1_len, state->modelInputs.resample_out_frame_1.begin());
            std::copy(h0_1, h0_1 + state->modelInputs.inputShapes.h0_1_len, state->modelInputs.h0_1.begin());
            std::copy(c0_1, c0_1 + state->modelInputs.inputShapes.c0_1_len, state->modelInputs.c0_1.begin());
            std::copy(conv_state_1, conv_state_1 + state->modelInputs.inputShapes.conv_state_1_len, state->modelInputs.conv_state_1.begin());

            state->modelInputs.frame_num_1[0] += 1;
        }

        return 1.0; // placeholder for now
    }

    int main()
    {
        printf("starts!\n");
        // Load your ONNX model and other setup as needed
        RNNModel model;

        // Create and initialize DenoiseState (rnnoise_create already calls rnnoise_init)
        DenoiseState *state = rnnoise_create(&model);

        // read_audio_file (defined above) returns a vector<float> containing the audio samples
        std::vector<float> input_data = read_audio_file("audio.wav");
        std::vector<float> output_data;

        const size_t frame_size = 128;

        printf("input data size = %zu\n",input_data.size());
        for (size_t i = 0; i + frame_size <= input_data.size(); i += frame_size) {
            // TIME CALCULATION START
            std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();

            std::vector<float> buffer(input_data.begin() + i, input_data.begin() + i + frame_size);

            rnnoise_process_frame(state, buffer.data(), buffer.data());

            // Append the processed frame to output data
            output_data.insert(output_data.end(), buffer.begin(), buffer.end());

            std::chrono::steady_clock::time_point end2 = std::chrono::steady_clock::now();
            std::chrono::duration<double> elapsed_seconds2 = end2 - start;
            double elapsed_time2 = elapsed_seconds2.count();
            printf("Elapsed time 2: %.8f seconds\n",elapsed_time2);
        }
        rnnoise_destroy(state);
        return 0;
    }
}

Steps to Reproduce:

1. Compile the C++ inference code with the options above (or any other ONNX inference code).
2. Test the inference time in a native environment.
3. Compile the code for WASM (both native and web).
4. Compare the inference times.

(I can provide the full project files if needed.)

Expected Result: The performance difference between native and WASM-compiled versions should be minimal.

Actual Result: Inference times increased significantly when compiling with Emscripten for WASM.

Additional Notes: Despite what the function names suggest, this is not RNNoise; I am implementing my own solution in place of RNNoise and simply kept the function names for convenience.

Any assistance or insight into resolving this performance discrepancy would be greatly appreciated, as would any advice about compilation options, Emscripten versions, or even a better-suited language for the WASM build.

sbc100 commented 1 year ago

What is WASM (Native) here as compared to Native (no WASM)?

yusuf-ilgun commented 1 year ago

Sorry for not being clearer: by "WASM (Native)" I meant Node.js. Let me rephrase all three.

Native (C++ on Ubuntu, no WASM): compiled with gcc and run directly on Ubuntu.

WASM on Node.js (Local): compiled with -sENVIRONMENT='node' and run on Node v20.

WASM (Web): compiled with -sENVIRONMENT='web' and run in browsers. I have tried three browsers so far (Chrome, Firefox, Safari, all kept up to date, so the latest version of each) and the results were the same.
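
For concreteness, a rough sketch of the three setups; the native link line and the output names are illustrative, not the exact project commands:

# 1. Native: compile with g++ and run directly on Ubuntu
g++ -O3 denoise_streaming.cpp wave.c decode.c -lonnxruntime -lfvad -o denoise_native
./denoise_native

# 2. WASM on Node.js: build with -sENVIRONMENT=node (full flags as in the issue) and run under Node v20
emcc ... -sENVIRONMENT=node ... -o rnnoise-node.js
node rnnoise-node.js

# 3. WASM (Web): build with -sENVIRONMENT=web and load the generated .js from a browser page
emcc ... -sENVIRONMENT=web ... -o rnnoise-web.js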

sbc100 commented 1 year ago

Hmm, it is very strange that you are seeing perf differences between Node.js and Chrome, since in both cases we are dealing with the same V8 engine. Perhaps a version difference might account for it.

Can you try building with -sENVIRONMENT=node,web so you can literally run the exact same build/binary in those two environments?
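
For reference, a minimal sketch of the build command from the issue with only the environment setting changed to cover both targets (all other flags assumed unchanged; the output name is illustrative):

emcc \
    -O3 -g0 \
    -msimd128 -flto -frtti \
    -sALLOW_MEMORY_GROWTH=1 \
    -sMALLOC=emmalloc \
    -sENVIRONMENT=node,web \
    --embed-file dns48_stream_depth=4_stride=128_quantized.onnx \
    --preload-file audio.wav \
    -sASSERTIONS=0 \
    -sSTACK_SIZE=128KB \
    -sINITIAL_MEMORY=22MB \
    -sSINGLE_FILE=1 \
    -L. -lonnxruntime_webassembly -lfvad \
    denoise_streaming.cpp wave.c decode.c \
    -o rnnoise.js

The same rnnoise.js output can then be run under Node and loaded from a browser page, so the Node vs. Chrome comparison uses one identical binary.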