serizba / cppflow

Run TensorFlow models in C++ without installation and without Bazel
https://serizba.github.io/cppflow/
MIT License

cppflow::decode_base64() not working. #207

Closed: edargham closed this issue 2 years ago

edargham commented 2 years ago

I'm trying to pass a base64-encoded image as input to the model, but the API throws an exception whenever I do. There is no usage example in the docs; does anyone know what is going wrong? Here's the code:

const char* ClassifierModel::outputB64(const char* inputB64)
{
    try
    {
        cppflow::tensor input{ cppflow::decode_base64(std::string{ inputB64 }) }; // Throws an exception.

        input = cppflow::cast(input, TF_UINT8, TF_FLOAT);
        input = input / 255.f;
        input = cppflow::expand_dims(input, 0);
        input = cppflow::resize_bicubic(input, cppflow::tensor({ 75, 150 }), true);

        cppflow::tensor out{ clfModel(input) };
        cppflow::tensor result{ cppflow::arg_max(out, 1) };
        std::vector<long long int> data{ result.get_data<long long int>() };
        const long long int idx{ data[0] };

        return classes.at(idx).c_str();
    }
    catch (const std::exception&)
    {
        throw; // rethrow the original exception without copying or slicing it
    }
}
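If cppflow::decode_base64() itself is what throws, the base64 alphabet is worth ruling out. cppflow's raw ops wrap TensorFlow ops, and TensorFlow documents its DecodeBase64 op as decoding web-safe base64: the input must use - and _ instead of + and / (trailing padding is optional). Images encoded with a standard encoder, or sent from a browser as a data URL, will usually contain + or /. The following is a minimal sketch of a hypothetical helper, to_web_safe_base64 (not part of cppflow), assuming the caller sends standard base64. It strips an optional "data:...;base64," prefix, drops line breaks and spaces as a precaution, and maps the alphabet.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper, not part of cppflow: converts standard base64
// ('+', '/') to the web-safe alphabet ('-', '_') expected by TensorFlow's
// DecodeBase64 op. It also strips an optional "data:<mime>;base64," prefix
// (as produced by browser data URLs) and drops line breaks and spaces that
// some encoders insert. Padding '=' is passed through unchanged.
std::string to_web_safe_base64(std::string s) {
    const std::string marker = ";base64,";
    if (s.rfind("data:", 0) == 0) {  // starts with "data:"
        const auto pos = s.find(marker);
        if (pos != std::string::npos) {
            s.erase(0, pos + marker.size());
        }
    }

    std::string out;
    out.reserve(s.size());
    for (const char c : s) {
        switch (c) {
            case '+': out.push_back('-'); break;
            case '/': out.push_back('_'); break;
            case '\n': case '\r': case ' ': case '\t': break;  // skip whitespace
            default: out.push_back(c); break;
        }
    }
    return out;
}
```

In the function above, it would be applied before decoding, for example std::string b64{ to_web_safe_base64(inputB64) }; and the result passed to cppflow::decode_base64().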


edargham commented 2 years ago

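After a lot of trial and error, I got it working. The missing step was cppflow::decode_jpeg(): decode_base64 only turns the base64 text back into the original JPEG file bytes (a string tensor), so the image still has to be decoded into a uint8 pixel tensor before it can be cast to float. Here is the updated version: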

const char* ClassifierModel::outputB64(const char* inputB64)
{
    try
    {
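        std::string b64{ inputB64 };
        // decode_base64 recovers the raw JPEG file bytes (a string tensor);
        // decode_jpeg then decodes them into a uint8 image tensor of shape [H, W, C].
        cppflow::tensor input{ cppflow::decode_jpeg(cppflow::decode_base64(b64)) };

        // Scale pixels to [0, 1], add a batch dimension, and resize to 75x150 (height x width).
        input = cppflow::cast(input, TF_UINT8, TF_FLOAT);
        input = input / 255.f;
        input = cppflow::expand_dims(input, 0);
        input = cppflow::resize_bicubic(input, cppflow::tensor({ 75, 150 }), true);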

        cppflow::tensor out{ clfModel(input) };
        cppflow::tensor result{ cppflow::arg_max(out, 1) };
        std::vector<long long int> data{ result.get_data<long long int>() };
        const long long int idx{ data[0] };

        return classes.at(idx).c_str();
    }
    catch (const std::exception&)
    {
        throw; // rethrow the original exception without copying or slicing it
    }
}
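The fix works because base64 decoding only undoes the text encoding: the bytes that come out are still a compressed JPEG file, which is why decode_jpeg has to run before the cast. To sanity-check an input on the C++ side before it reaches TensorFlow, here is a minimal sketch built on two hypothetical helpers, base64_decode and looks_like_jpeg (neither is part of cppflow). It decodes the string in either alphabet and checks for the JPEG start-of-image marker. A PNG, for instance, would fail this check and would need TensorFlow's DecodePng op instead. It assumes C++17 for std::optional.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical helper, not part of cppflow: decodes base64 in either the
// standard ('+', '/') or web-safe ('-', '_') alphabet. Returns std::nullopt
// if a character outside the alphabet appears; decoding stops at the first
// '=' padding character.
std::optional<std::string> base64_decode(const std::string& in) {
    auto sextet = [](char c) -> int {
        if (c >= 'A' && c <= 'Z') return c - 'A';
        if (c >= 'a' && c <= 'z') return c - 'a' + 26;
        if (c >= '0' && c <= '9') return c - '0' + 52;
        if (c == '+' || c == '-') return 62;
        if (c == '/' || c == '_') return 63;
        return -1;
    };

    std::string out;
    unsigned int buffer = 0;  // bits not yet emitted as a byte
    int bits = 0;             // number of valid bits in buffer
    for (const char c : in) {
        if (c == '=') break;
        const int v = sextet(c);
        if (v < 0) return std::nullopt;
        buffer = (buffer << 6) | static_cast<unsigned int>(v);
        bits += 6;
        if (bits >= 8) {
            bits -= 8;
            out.push_back(static_cast<char>((buffer >> bits) & 0xFFu));
            buffer &= (1u << bits) - 1u;  // keep only the leftover bits
        }
    }
    return out;
}

// JPEG files begin with the start-of-image marker 0xFF 0xD8, followed by
// the 0xFF that opens the next marker segment.
bool looks_like_jpeg(const std::string& bytes) {
    return bytes.size() >= 3 &&
           static_cast<unsigned char>(bytes[0]) == 0xFF &&
           static_cast<unsigned char>(bytes[1]) == 0xD8 &&
           static_cast<unsigned char>(bytes[2]) == 0xFF;
}
```

In the working function, a check like this could run on inputB64 before the tensor is built, so non-JPEG input produces a clear error message instead of an opaque TensorFlow failure.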