serizba / cppflow

Run TensorFlow models in C++ without installation and without Bazel
https://serizba.github.io/cppflow/
MIT License

Load model from resources #249

Open. charvey2718 opened this issue 11 months ago.


Many thanks for this extremely helpful library.

Currently, models are loaded from a file by the model constructor, which calls readGraph. Sometimes it is helpful to load a model from embedded resources instead (i.e. a .pb file compiled into the executable). I have solved this for frozen graphs and am posting it here as a suggestion, in case you want to add it to master.

(I'm not that familiar with GitHub, so please excuse me for not creating a branch and opening a pull request, or whatever the terms are!)

I added a new constructor to model.h that takes a pointer to a std::vector<uchar> as its only parameter. It builds the TF_Buffer by passing bufferModel->data() and bufferModel->size() to TF_NewBufferFromString, instead of calling readGraph(filename) as the existing file-based constructor does.

```cpp
// Constructs a model from a serialized frozen graph (GraphDef) held in memory.
// Also declare `explicit model(const std::vector<uchar>* bufferModel);` in the class.
// `uchar` is assumed to be the project's typedef for unsigned char.
inline model::model(const std::vector<uchar>* bufferModel)
{
    if (bufferModel == nullptr || bufferModel->empty())
    {
        throw std::runtime_error("Model buffer is null or empty");
    }

    this->status = {TF_NewStatus(), &TF_DeleteStatus};
    this->graph = {TF_NewGraph(), TF_DeleteGraph};

    // Create the session.
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>
        session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};

    auto session_deleter = [this](TF_Session* sess) {
        TF_DeleteSession(sess, this->status.get());
        status_check(this->status.get());
    };

    this->session = {TF_NewSession(this->graph.get(),
            session_options.get(),
            this->status.get()),
        session_deleter};
    status_check(this->status.get());

    // Import the graph definition. TF_NewBufferFromString copies the bytes,
    // so *bufferModel only needs to stay alive for the duration of this call.
    TF_Buffer* def = TF_NewBufferFromString(bufferModel->data(), bufferModel->size());
    if (def == nullptr)
    {
        throw std::runtime_error("Failed to create a TF_Buffer from the model data");
    }

    std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
            TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
    TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), this->status.get());
    TF_DeleteBuffer(def);

    status_check(this->status.get());
}
```
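On platforms without Windows resources, a similar effect can be had by compiling the .pb file into the binary as a byte array, for example with `xxd -i frozen_graph.pb`, which emits an `unsigned char frozen_graph_pb[]` array plus an `unsigned int frozen_graph_pb_len`. The sketch below assumes such an array; the bytes shown are placeholders rather than a real graph, and `model_buffer_from_array` is a hypothetical helper, not part of cppflow.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the output of `xxd -i frozen_graph.pb`
// (xxd emits non-const arrays like these). The bytes are placeholders;
// a real array holds the whole serialized GraphDef.
unsigned char frozen_graph_pb[] = {0x01, 0x02, 0x03, 0x04};
unsigned int frozen_graph_pb_len = sizeof(frozen_graph_pb);

// Copy an embedded byte array into the std::vector that the
// buffer-based model constructor takes.
std::vector<unsigned char> model_buffer_from_array(const unsigned char* data,
                                                   std::size_t len)
{
    assert(data != nullptr || len == 0);
    return std::vector<unsigned char>(data, data + len);
}
```

Assuming the constructor above and that `uchar` is `unsigned char`, the result would then be passed as `cppflow::model m(&buffer);`. Copying into a vector costs one extra copy of the model bytes, but it keeps the constructor's single pointer-to-vector signature.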

I then load the .pb model from resources into a std::vector<uchar> using a LoadModel function in my own project's source code, and pass a pointer to that vector to the new model constructor. My project happens to use wxWidgets on Windows, so this is conveniently done with the Win32 resource functions as follows. I include it only in case it helps someone in future; it is not itself a suggestion for cppflow.

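```cpp
#include <windows.h>
#include <vector>
// wxString and wxGetInstance() come from the wxWidgets headers; `uchar` is
// assumed to be a typedef for unsigned char.

// Copies the RT_RCDATA resource named resName into model.
// Returns false, leaving model untouched, if the resource is missing or empty.
bool LoadModel(const wxString& resName, std::vector<uchar>& model)
{
    HRSRC hrsrc = FindResource(wxGetInstance(), resName, RT_RCDATA);
    if (hrsrc == nullptr) return false;

    HGLOBAL hglobal = LoadResource(wxGetInstance(), hrsrc);
    if (hglobal == nullptr) return false;

    // LockResource returns a pointer to the resource bytes; nothing needs unlocking.
    const void* data = LockResource(hglobal);
    if (data == nullptr) return false;

    DWORD datalen = SizeofResource(wxGetInstance(), hrsrc);
    if (datalen == 0) return false;

    const uchar* charBuf = static_cast<const uchar*>(data);
    model.assign(charBuf, charBuf + datalen);
    return true;
}
```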