flashlight / wav2letter

Facebook AI Research's Automatic Speech Recognition Toolkit
https://github.com/facebookresearch/wav2letter/wiki

Decoder freezes on concurrent requests #440

Closed yutkin closed 3 years ago

yutkin commented 5 years ago

I'm trying to embed the decoder in a gRPC service and use it for concurrent requests. Everything works in the sequential version (i.e. only one request at a time, handled on the main thread). But when I run each decoding request in a separate thread, the network freezes on the forward pass.

gRPC server

server.h

#include "model.h"

class AsrAsync final {
public:
    AsrAsync(std::shared_ptr<AsrModel> asrModel, uint64 threadPoolSize): model(asrModel), threadPool(threadPoolSize) {}

    ~AsrAsync() {
        server_->Shutdown();
        cq_->Shutdown();
    }

    void Run(const std::string& addr) {
        ServerBuilder builder;
        builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
        builder.RegisterService(&service_);
        cq_ = builder.AddCompletionQueue();
        server_ = builder.BuildAndStart();

        HandleRpcs();
    }

private:
    class CallData {
    public:
        Status asyncHandle(const AsrRequest& request_, AsrResult& reply_, ServerContext& ctx_);

        CallData(asr::AsrServicer::AsyncService* service, ServerCompletionQueue* cq, std::shared_ptr<AsrModel> asrModel)
                : service_(service), cq_(cq), responder_(&ctx_), status_(CREATE), model(asrModel) {
            Proceed();
        }

        void Proceed() {
            if (status_ == CREATE) {
                status_ = PROCESS;

                service_->RequestRecognizeSpeech(&ctx_, &request_, &responder_, cq_, cq_, this);
            } else if (status_ == PROCESS) {

                new CallData(service_, cq_, this->model);

                auto status = asyncHandle(request_, reply_, ctx_);

                status_ = FINISH;

                responder_.Finish(reply_, status, this);
            } else {
                GPR_ASSERT(status_ == FINISH);
                delete this;
            }
        }

    private:
        asr::AsrServicer::AsyncService* service_;
        ServerCompletionQueue* cq_;
        ServerContext ctx_;

        AsrRequest request_;
        AsrResult reply_;

        ServerAsyncResponseWriter<AsrResult> responder_;

        std::shared_ptr<AsrModel> model;

        enum CallStatus { CREATE, PROCESS, FINISH };
        CallStatus status_;  // The current serving state.
    };

    void HandleRpcs() {
        new CallData(&service_, cq_.get(), this->model);
        void* tag;  // uniquely identifies a request.
        bool ok;

        while (cq_->Next(&tag, &ok)) {
            GPR_ASSERT(ok);

            // Run request handler in a thread pool
            threadPool.enqueue(std::bind(&CallData::Proceed, static_cast<CallData*>(tag)));

            // Uncomment for use only in one thread. This works!!!
            // static_cast<CallData*>(tag)->Proceed();
        }
    }

private:
    std::unique_ptr<grpc::ServerCompletionQueue> cq_;
    asr::AsrServicer::AsyncService service_;
    std::unique_ptr<Server> server_;

    std::shared_ptr<AsrModel> model;

    fl::ThreadPool threadPool;
};

server.cc

#include "server.h"

Status AsrAsync::CallData::asyncHandle(const AsrRequest &request_, AsrResult &reply_, ServerContext &ctx_)
{
    RepeatedPtrField<std::string> urls = request_.url();

    // Downloading and saving of URLs ...

    auto result = this->model->decode();
    reply_.mutable_text()->Reserve(urls.size());

    for (int i = 0; i < urls.size(); ++i)
    {
        reply_.add_text(result[i]);
    }

    return Status::OK;
}

Decoder

model.h

class AsrModel {
public:
    AsrModel(int argc, char *argv[]);
    ~AsrModel() {}
    std::vector<std::string> decode();
    void runDecoder(int tid, int start, int end, const w2l::EmissionSet& emissionSet, std::vector<int>& shuffle_order, std::vector<std::string>& result);

private:
    std::mutex decodeMutex;

    std::shared_ptr<fl::Module> network_;
    std::shared_ptr<w2l::LM> lm_;
    std::shared_ptr<w2l::Trie> trie_;

    w2l::DecoderOptions decoderOpt_;
    int unkWordIdx_;
    int blankIdx_;
    int silIdx_;

    w2l::Dictionary tokenDict_;
    w2l::LexiconMap lexicon_;
    w2l::Dictionary wordDict_;
    w2l::DictionaryMap dicts_;
};

model.cc

#include "model.h"

using namespace std::placeholders;

AsrModel::AsrModel(int argc, char *argv[])
{
    auto flagsfile = w2l::FLAGS_flagsfile;
    if (!flagsfile.empty())
    {
        LOG(INFO) << "Reading flags from file " << flagsfile;
        gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
        // Re-parse command line flags to override values in the flag file.
        gflags::ParseCommandLineFlags(&argc, &argv, false);
    }

    std::unordered_map<std::string, std::string> cfg;

    if (!w2l::FLAGS_am.empty())
    {
        LOG(INFO) << "[Network] Reading acoustic model from " << w2l::FLAGS_am;
        af::setDevice(0);
        w2l::W2lSerializer::load(w2l::FLAGS_am, cfg, this->network_);
        this->network_->eval();
        DLOG(INFO) << "[Network] " << this->network_->prettyString();
        LOG(INFO) << "[Network] Number of params: " << w2l::numTotalParams(this->network_);
    }
    else
    {
        LOG(FATAL) << "[Network] Fail to load network. Flag --am is not provided.";
    }

    auto flags = cfg.find(w2l::kGflags);
    if (flags == cfg.end())
    {
        LOG(FATAL) << "[Network] Invalid config loaded from " << w2l::FLAGS_am;
    }
    LOG(INFO) << "[Network] Updating flags from config file: " << w2l::FLAGS_am;
    gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);

    // override with user-specified flags
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    if (!flagsfile.empty())
    {
        gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
        // Re-parse command line flags to override values in the flag file.
        gflags::ParseCommandLineFlags(&argc, &argv, false);
    }

    this->decoderOpt_ = w2l::DecoderOptions(
        w2l::FLAGS_beamsize, static_cast<float>(w2l::FLAGS_beamthreshold), static_cast<float>(w2l::FLAGS_lmweight),
        static_cast<float>(w2l::FLAGS_wordscore), static_cast<float>(w2l::FLAGS_unkweight), w2l::FLAGS_logadd,
        static_cast<float>(w2l::FLAGS_silweight), w2l::CriterionType::CTC);

    LOG(INFO) << "Reading lexicon and token dictionary";

    w2l::Dictionary tokenDict(w2l::pathsConcat(w2l::FLAGS_tokensdir, w2l::FLAGS_tokens));

    for (int64_t r = 1; r <= w2l::FLAGS_replabel; ++r)
    {
        tokenDict.addEntry(std::to_string(r));
    }
    tokenDict.addEntry(w2l::kBlankToken);
    this->tokenDict_ = tokenDict;

    this->lexicon_ = w2l::loadWords(w2l::FLAGS_lexicon, w2l::FLAGS_maxword);
    this->wordDict_ = w2l::createWordDict(this->lexicon_);
    this->dicts_ = {{w2l::kTargetIdx, this->tokenDict_}, {w2l::kWordIdx, this->wordDict_}};

    this->unkWordIdx_ = this->wordDict_.getIndex(w2l::kUnkToken);
    this->blankIdx_ = w2l::FLAGS_criterion == w2l::kCtcCriterion ? this->tokenDict_.getIndex(w2l::kBlankToken) : -1;
    this->silIdx_ = this->tokenDict_.getIndex(w2l::FLAGS_wordseparator);

    LOG(INFO) << "Building a language model";

    this->lm_ = std::make_shared<w2l::KenLM>(w2l::FLAGS_lm, this->wordDict_);
    if (!this->lm_) {
        LOG(FATAL) << "[LM constructing] Failed to load LM: " << w2l::FLAGS_lm;
    }
    LOG(INFO) << "LM constructed";

    int silIdx = this->tokenDict_.getIndex(w2l::FLAGS_wordseparator);
    this->trie_ = std::make_shared<w2l::Trie>(this->tokenDict_.indexSize(), silIdx);

    auto startState = this->lm_->start(false);

    for (auto &it : this->lexicon_)
    {
        const std::string &word = it.first;
        int usrIdx = this->wordDict_.getIndex(word);
        float score = -1;
        w2l::LMStatePtr dummyState;
        std::tie(dummyState, score) = this->lm_->score(startState, usrIdx);

        for (auto &tokens : it.second)
        {
            auto tokensTensor = tkn2Idx(tokens, this->tokenDict_, w2l::FLAGS_replabel);
            this->trie_->insert(tokensTensor, usrIdx, score);
        }
    }
    this->trie_->smear(w2l::SmearingMode::MAX);
    LOG(INFO) << "[Decoder] Trie planted";
}

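// Restore the original sample order: result[i] belongs at index shuffle_order[i].
// Pairwise swapping in place does not apply a permutation correctly, so build a new vector.
void repair_order(std::vector<std::string> &result, const std::vector<int> &shuffle_order)
{
    std::vector<std::string> ordered(result.size());
    for (size_t i = 0; i < result.size(); ++i)
    {
        ordered[shuffle_order[i]] = std::move(result[i]);
    }
    result = std::move(ordered);
}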

void AsrModel::runDecoder(int tid, int start, int end, const w2l::EmissionSet &emissionSet,
                          std::vector<int> &shuffle_order, std::vector<std::string> &result)
{
    try
    {
        auto decoder = std::make_unique<w2l::WordLMDecoder>(this->decoderOpt_, this->trie_, this->lm_, this->silIdx_,
                                                            this->blankIdx_, this->unkWordIdx_, emissionSet.transition);

        for (int s = start; s < end; s++)
        {
            auto emission = emissionSet.emissions[s];
            auto sampleId = emissionSet.sampleIds[s];
            auto T = emissionSet.emissionT[s];
            auto N = emissionSet.emissionN;
            auto results = decoder->decode(emission.data(), T, N);
            auto &rawWordPrediction = results[0].words;

            rawWordPrediction = w2l::validateIdx(rawWordPrediction, this->wordDict_.getIndex(w2l::kUnkToken));
            auto wordPrediction = wrdIdx2Wrd(rawWordPrediction, this->wordDict_);

            std::lock_guard<std::mutex> guard(this->decodeMutex);

            auto wordPredictionStr = w2l::join(" ", wordPrediction);
            shuffle_order.push_back(std::stoi(sampleId) - 1);
            result.push_back(wordPredictionStr);
        }
    }
    catch (const std::exception &exc)
    {
        LOG(FATAL) << "Exception in thread " << tid << "\n" << exc.what();
    }
}

std::vector<std::string> AsrModel::decode()
{
    auto decodePreparationTimer = fl::TimeMeter();
    decodePreparationTimer.resume();

    int worldRank = 0;
    int worldSize = 1;
    int batchSize = 1;

    auto ds = w2l::createDataset(w2l::FLAGS_test, this->dicts_, this->lexicon_, batchSize, worldRank, worldSize);
    int cnt = 0;
    w2l::EmissionSet emissionSet;

    LOG(INFO) << "[Serialization] Running forward pass ...";

    for (auto &sample : *ds)
    {
        af::print("Input array", sample[w2l::kInputIdx]);
        auto rawEmission = this->network_->forward({fl::input(sample[w2l::kInputIdx])}).front();
        int N = rawEmission.dims(0);
        int T = rawEmission.dims(1);

        auto emission = w2l::afToVector<float>(rawEmission);
        auto tokenTarget = w2l::afToVector<int>(sample[w2l::kTargetIdx]);
        auto wordTarget = w2l::afToVector<int>(sample[w2l::kWordIdx]);

        std::vector<std::string> wordTargetStr = w2l::wrdIdx2Wrd(wordTarget, this->wordDict_);

        emissionSet.emissions.emplace_back(emission);
        emissionSet.wordTargets.emplace_back(wordTargetStr);
        emissionSet.tokenTargets.emplace_back(tokenTarget);
        emissionSet.emissionT.emplace_back(T);
        emissionSet.emissionN = N;
        emissionSet.sampleIds.emplace_back(w2l::readSampleIds(sample[w2l::kSampleIdx]).front());

        ++cnt;
        if (cnt == w2l::FLAGS_maxload)
        {
            break;
        }
    }

    int nSample = emissionSet.emissions.size();

    auto timeForDecodePreparation = static_cast<float>(decodePreparationTimer.value());
    LOG(INFO) << "Total time spent for decoding preparation: " << timeForDecodePreparation
              << " sec. Per sample: " << timeForDecodePreparation / nSample << " sec.";

    LOG(INFO) << "Decoding begin";

    // Decoding
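    // Reserve capacity only: runDecoder() appends with push_back, so pre-sized
    // vectors would end up with nSample empty entries at the front.
    std::vector<int> shuffle_order;
    std::vector<std::string> result;
    shuffle_order.reserve(nSample);
    result.reserve(nSample);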

    int nSamplePerThread = std::ceil(nSample / static_cast<float>(w2l::FLAGS_nthread_decoder));
    LOG(INFO) << "[Dataset] Number of samples per thread: " << nSamplePerThread;

    af::deviceGC();

    auto startThreads = [&]() {
        auto decodeFunc = std::bind(&AsrModel::runDecoder, this, _1, _2, _3, std::cref(emissionSet),
                                    std::ref(shuffle_order), std::ref(result));

        if (w2l::FLAGS_nthread_decoder == 1)
        {
            decodeFunc(0, 0, nSample);
        }
        else if (w2l::FLAGS_nthread_decoder > 1)
        {
            fl::ThreadPool threadPool(w2l::FLAGS_nthread_decoder);
            for (int i = 0; i < w2l::FLAGS_nthread_decoder; i++)
            {
                int start = i * nSamplePerThread;
                if (start >= nSample)
                {
                    break;
                }
                int end = std::min((i + 1) * nSamplePerThread, nSample);
                threadPool.enqueue(decodeFunc, i, start, end);
            }
        }
    };

    auto timer = fl::TimeMeter();
    timer.resume();
    startThreads();
    repair_order(result, shuffle_order);
    auto timeForDecode = static_cast<float>(timer.value());
    LOG(INFO) << "Total time spent for decoding: " << timeForDecode
              << " sec. Per sample: " << timeForDecode / result.size() << " sec.";

    return result;
}

Example output:

root@f2822198be7d:/app# ./asr --flagsfile configs/decode.cfg &
[1] 2377
root@f2822198be7d:/app# I1103 13:16:11.821629  2377 main.cpp:31]
ArrayFire v3.6.4 (CUDA, 64-bit Linux, build 1b8030c5)
Platform: CUDA Toolkit 10.1, Driver: 430.14
[0] GeForce GTX TITAN X, 12213 MB, CUDA Compute 5.2
-1- GeForce GTX TITAN X, 12213 MB, CUDA Compute 5.2
I1103 13:16:11.858422  2377 main.cpp:33] gRPC thread pool size=8
I1103 13:16:11.858474  2377 model.cc:14] Reading flags from file configs/decode.cfg
I1103 13:16:11.858712  2377 model.cc:24] [Network] Reading acoustic model from configs/am.binary
I1103 13:16:12.406680  2377 model.cc:29] [Network] Number of params: 3920452
I1103 13:16:12.406729  2377 model.cc:41] [Network] Updating flags from config file: configs/am.binary
I1103 13:16:12.407167  2377 model.cc:58] Reading lexicon and token dictionary
I1103 13:16:12.781265  2377 model.cc:77] Building a language model
I1103 13:16:12.821374  2377 model.cc:83] LM constructed
I1103 13:16:13.379698  2377 model.cc:105] [Decoder] Trie planted
I1103 13:16:13.380503  2377 main.cpp:43] Service started at 0.0.0.0:5000

root@f2822198be7d:/app# ./test_client

I1103 13:16:20.003556  2381 W2lListFilesDataset.cpp:113] 100 files found.
I1103 13:16:20.003609  2381 Utils.cpp:98] Filtered 0/100 samples
I1103 13:16:20.003651  2381 W2lListFilesDataset.cpp:45] Total batches (i.e. iters): 100
I1103 13:16:20.003675  2381 model.cc:162] [Serialization] Running forward pass ...
Input array
[411 40 1 1]


At this stage the process freezes. If I run the single-threaded version, the array prints normally and both the forward pass and decoding complete successfully.

Why does execution freeze in the multi-threaded version?

lunixbochs commented 5 years ago

You should probably just use multiple processes if you want parallel executions. I don't think it's intended for this to be thread safe.

yutkin commented 5 years ago

If needed, thread safety can be achieved by protecting critical sections. In my example, though, the server handles only one request; there are no concurrent requests. It fails only when I run "forward" in a separate thread.
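For readers unfamiliar with the idea, protecting a critical section could look roughly like the sketch below. It is not taken from wav2letter or flashlight: `FakeModel`, `GuardedModel` and `runConcurrently` are hypothetical names, and the "forward pass" is a trivial stand-in. As noted in the comment above, the reported freeze happens even with a single request, so a mutex on its own may not resolve it.

```cpp
#include <cassert>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for a model whose forward pass is not thread-safe.
class FakeModel {
public:
    int forward(int input) {
        ++calls_;  // unsynchronized shared state
        return input * 2;
    }
    int calls() const { return calls_; }

private:
    int calls_ = 0;
};

// Wrapper that makes every access to the model a critical section.
class GuardedModel {
public:
    int forward(int input) {
        std::lock_guard<std::mutex> lock(mutex_);  // one caller at a time
        return model_.forward(input);
    }
    int calls() {
        std::lock_guard<std::mutex> lock(mutex_);
        return model_.calls();
    }

private:
    std::mutex mutex_;
    FakeModel model_;
};

// Calls forward() from nThreads threads and returns the per-thread outputs.
std::vector<int> runConcurrently(GuardedModel& model, int nThreads) {
    assert(nThreads > 0);
    std::vector<int> outputs(nThreads);
    std::vector<std::thread> threads;
    for (int i = 0; i < nThreads; ++i) {
        // Each thread writes only its own slot, so outputs needs no lock.
        threads.emplace_back([&model, &outputs, i] { outputs[i] = model.forward(i); });
    }
    for (auto& t : threads) {
        t.join();
    }
    return outputs;
}
```

Calling `runConcurrently(model, 8)` serializes the eight forward calls. This gives correctness for shared state but no parallel speed-up, and it does not help if the underlying library ties its state to the thread that created it.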

lunixbochs commented 5 years ago

FWIW I run this stack against MKL instead of CUDA just fine with resources created and consumed in different threads (but only one thread using resources at a time)

You should attach a debugger and print the thread stacks to see where it's frozen.

lunixbochs commented 5 years ago

Also, try this for your forward pass instead of your W2lDataset code: https://github.com/talonvoice/wav2letter/blob/decoder/w2l_forward.cpp#L64

The W2lDataset code does a lot of gnarly stuff internally, for example here's a threadpool: https://github.com/facebookresearch/wav2letter/blob/master/src/data/W2lDataset.cpp#L30

yutkin commented 5 years ago

It is very difficult to embed process-based parallelism in our current app, so I need to find a solution for using the network from different threads.
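One pattern that keeps a thread-based server while ensuring all model work runs on a single thread is a dedicated worker with a job queue. The sketch below is an assumption-laden illustration, not something confirmed in this thread: `ModelWorker` and `submit` are hypothetical names, the job is an arbitrary callable returning `int`, and whether this avoids the freeze with ArrayFire/CUDA is untested here. In a real service the model would presumably also be loaded by a job submitted to the same worker.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

// Hypothetical worker: every submitted job runs on one dedicated thread.
class ModelWorker {
public:
    ModelWorker() : thread_([this] { loop(); }) {}

    ~ModelWorker() {
        {
            std::lock_guard<std::mutex> lock(mutex_);
            stop_ = true;
        }
        cv_.notify_one();
        thread_.join();  // pending jobs are drained before the thread exits
    }

    // Callable from any thread; the job itself executes on the worker thread.
    std::future<int> submit(std::function<int()> job) {
        std::packaged_task<int()> task(std::move(job));
        std::future<int> result = task.get_future();
        {
            std::lock_guard<std::mutex> lock(mutex_);
            jobs_.push(std::move(task));
        }
        cv_.notify_one();
        return result;
    }

private:
    void loop() {
        for (;;) {
            std::packaged_task<int()> task;
            {
                std::unique_lock<std::mutex> lock(mutex_);
                cv_.wait(lock, [this] { return stop_ || !jobs_.empty(); });
                if (jobs_.empty()) {
                    return;  // stop requested and nothing left to run
                }
                task = std::move(jobs_.front());
                jobs_.pop();
            }
            task();  // run outside the lock so submit() is never blocked by a job
        }
    }

    std::mutex mutex_;
    std::condition_variable cv_;
    std::queue<std::packaged_task<int()>> jobs_;
    bool stop_ = false;
    std::thread thread_;  // declared last so the members above exist before loop() starts
};
```

A request handler running on any pool thread would call something like `worker.submit([&] { return runForward(); }).get()`, where `runForward` is a placeholder for the actual forward pass. Requests are still processed one at a time, but always on the same thread.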

yutkin commented 5 years ago

Related: https://github.com/facebookresearch/flashlight/issues/73

tlikhomanenko commented 3 years ago

Closing due to inactivity and the huge improvements in our codebase since then.