ggerganov / whisper.cpp

Port of OpenAI's Whisper model in C/C++
MIT License
34.89k stars 3.56k forks source link

Subsequent requests taking more time when sending audio via websockets #1197

Closed btofficiel closed 1 year ago

btofficiel commented 1 year ago

I am trying to transcribe audio bytes coming as websocket messages and while the transcription is coming out accurately, the messages that come after the first one take more time and generate more segments for the same audio.

Can anyone help me by pointing out why that may be happening?

// Original (buggy) version from the report. The whisper_context created by
// whisper_init_from_file() embeds its own decoding state, and whisper_full()
// reuses that state on every call. Since the same ctx is shared by all
// websocket messages, state from earlier messages apparently carries over into
// later ones — which would explain why subsequent requests take longer and
// produce more segments for the same audio (the thread's follow-up fixes this
// by allocating a fresh whisper_state per request). Setting no_context=true
// alone does not isolate requests here.
int main() {
    YAML::Node config = YAML::LoadFile("config/config.yaml");
    std::string model_path = config["model_path"].as<std::string>();;  // NOTE(review): stray double semicolon
    // Context + implicit internal state, shared by every message handler call below.
    struct whisper_context * ctx = whisper_init_from_file(model_path.c_str());
    struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
    wparams.language = "en";
    wparams.no_context = true;   // disables cross-call text conditioning, but not the shared decoder state
    wparams.audio_ctx = 0;
    wparams.prompt_n_tokens = 0;
    wparams.prompt_tokens = nullptr;
    uWS::App().ws<UserData>("/ping", {
        /* Settings */
        .compression = uWS::CompressOptions(uWS::DEDICATED_COMPRESSOR_4KB | uWS::DEDICATED_DECOMPRESSOR),
        .maxPayloadLength = 100 * 1024 * 1024,
        .idleTimeout = 16,
        .maxBackpressure = 100 * 1024 * 1024,
        .closeOnBackpressureLimit = false,
        .resetIdleTimeoutOnSend = false,
        .sendPingsAutomatically = true,
        /* Handlers */
        .upgrade = nullptr,
        .open = [](auto *ws) {
            /* Open event here, you may access ws->getUserData() which points to a PerSocketData struct */
            UserData *userData = (UserData *) ws->getUserData();
            userData->pcm.push_back(0.0);

        },
        // Runs once per websocket message: decodes the payload to PCM float
        // samples and transcribes them with the SHARED ctx — this is the bug.
        .message = [&, ctx, wparams](auto *ws, std::string_view message, uWS::OpCode opCode) {
            std::vector<float> pcmf32_samples = AudioLoader::generatePCMF32samples(message);

            if (whisper_full(ctx, wparams, pcmf32_samples.data(), pcmf32_samples.size()) != 0) {
                fprintf(stderr, "failed to process audio\n");
            }
            // NOTE(review): on failure this still reads segments from ctx.

            int n_segments = whisper_full_n_segments(ctx);
            std::cout << "N_Segments: " << n_segments << std::endl;
            for (int i = 0; i < n_segments; ++i) {
                const char * text = whisper_full_get_segment_text(ctx, i);
                // Final segment is flagged so the client knows the response is complete.
                ws->send(create_response(text, i == n_segments-1), uWS::OpCode::TEXT, false);
            }

        },
        .drain = [](auto */*ws*/) {
            /* Check ws->getBufferedAmount() here */
        },
        .ping = [](auto */*ws*/, std::string_view) {
            /* Not implemented yet */
        },
        .pong = [](auto */*ws*/, std::string_view) {
            /* Not implemented yet */
        },
        .close = [](auto */*ws*/, int /*code*/, std::string_view /*message*/) {
            /* You may access ws->getUserData() here */
        }
    }).listen(3000, [](auto *listen_socket) {
            if (listen_socket) {
                        std::cout << "Listening on port " << 3000 << std::endl;
            }
        }).run();

        // Reached only if run() returns, i.e. the listen socket was not created.
        std::cout << "Failed to listen on port 3000" << std::endl;
}
btofficiel commented 1 year ago

Did some more reading of the whisper.h header file and fixed the issue.

int main() {
    YAML::Node config = YAML::LoadFile("config/config.yaml");
    std::string model_path = config["model_path"].as<std::string>();;
    struct whisper_context * ctx = whisper_init_from_file_no_state(model_path.c_str());
    struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
    uWS::App().ws<UserData>("/ping", {
        /* Settings */
        .compression = uWS::CompressOptions(uWS::DEDICATED_COMPRESSOR_4KB | uWS::DEDICATED_DECOMPRESSOR),
        .maxPayloadLength = 100 * 1024 * 1024,
        .idleTimeout = 16,
        .maxBackpressure = 100 * 1024 * 1024,
        .closeOnBackpressureLimit = false,
        .resetIdleTimeoutOnSend = false,
        .sendPingsAutomatically = true,
        /* Handlers */
        .upgrade = nullptr,
        .open = [](auto *ws) {
            /* Open event here, you may access ws->getUserData() which points to a PerSocketData struct */
            UserData *userData = (UserData *) ws->getUserData();
            userData->pcm.push_back(0.0);

        },
        .message = [&, ctx, wparams](auto *ws, std::string_view message, uWS::OpCode opCode) {
            std::vector<float> pcmf32_samples = AudioLoader::generatePCMF32samples(message);

            struct whisper_state * state = whisper_init_state(ctx);

            if (whisper_full_with_state(ctx, state, wparams, pcmf32_samples.data(), pcmf32_samples.size()) != 0) {
                fprintf(stderr, "failed to process audio\n");
            }
            int n_segments = whisper_full_n_segments_from_state(state);
            std::cout << "N_Segments: " << n_segments << std::endl;
            for (int i = 0; i < n_segments; ++i) {
                const char * text = whisper_full_get_segment_text_from_state(state, i);
                ws->send(create_response(text, i == n_segments-1), uWS::OpCode::TEXT, false);
            }
            whisper_free_state(state);

        },
        .drain = [](auto */*ws*/) {
            /* Check ws->getBufferedAmount() here */
        },
        .ping = [](auto */*ws*/, std::string_view) {
            /* Not implemented yet */
        },
        .pong = [](auto */*ws*/, std::string_view) {
            /* Not implemented yet */
        },
        .close = [](auto */*ws*/, int /*code*/, std::string_view /*message*/) {
            /* You may access ws->getUserData() here */
        }
    }).listen(3000, [](auto *listen_socket) {
            if (listen_socket) {
                        std::cout << "Listening on port " << 3000 << std::endl;
            }
        }).run();

        std::cout << "Failed to listen on port 3000" << std::endl;
}