serizba / cppflow

Run TensorFlow models in C++ without installation and without Bazel
https://serizba.github.io/cppflow/
MIT License

Not able to create a vector of DT_STRING #196

Closed: ns-wxin closed this issue 2 years ago.

ns-wxin commented 2 years ago

Hi,

My model looks like the following:

signature_def['serving_default']:
    The given SavedModel SignatureDef contains the following input(s):
      inputs['sentences'] tensor_info:
          dtype: DT_STRING
          shape: (-1)
          name: serving_default_sentences:0
    The given SavedModel SignatureDef contains the following output(s):
      outputs['output_0'] tensor_info:
          dtype: DT_FLOAT
          shape: (-1, 50)
          name: StatefulPartitionedCall_2:0
    Method name is: tensorflow/serving/predict

The input is a 1-D array of DT_STRING (shape (-1)). I'm not able to build that input tensor with your interface: there is only a constructor that takes a single std::string, and none that takes a std::vector<std::string>. The following code does not work. Please advise.

string sent1 {"I enjoy taking long walks along the beach with my dog."};
auto sentence = cppflow::tensor(sent1);
cppflow::model model(modelPath);
auto output = model({{"serving_default_sentences:0", sentence}},{"StatefulPartitionedCall_2:0"});
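The signature above declares `output_0` with shape `(-1, 50)`, meaning one row of 50 floats per input sentence. Assuming the result is read back as a flat, row-major `std::vector<float>` (for example through cppflow's `tensor::get_data<float>()`), it can be split into per-sentence vectors with a small helper. `split_rows` is a hypothetical name used only for illustration; this is a sketch, not part of cppflow.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: split a flat, row-major buffer of shape
// (num_rows, row_width) into one std::vector<float> per row.
// For this model, row_width would be 50 and num_rows the number of sentences.
std::vector<std::vector<float>> split_rows(const std::vector<float>& flat,
                                           std::size_t row_width) {
    if (row_width == 0 || flat.size() % row_width != 0) {
        throw std::invalid_argument("flat size is not a multiple of row_width");
    }
    std::vector<std::vector<float>> rows;
    rows.reserve(flat.size() / row_width);
    for (std::size_t start = 0; start < flat.size(); start += row_width) {
        auto first = flat.begin() + static_cast<std::ptrdiff_t>(start);
        auto last = first + static_cast<std::ptrdiff_t>(row_width);
        rows.emplace_back(first, last);  // copy one row of row_width values
    }
    return rows;
}
```

Throwing on a size mismatch keeps a wrong `row_width` from silently producing misaligned embeddings.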
serizba commented 2 years ago

Hi @ns-wxin

Following up on my comment in #200, a tensor holding a vector of TF_STRING elements could be created with a constructor like this:

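// Proposed constructor: build a TF_STRING tensor from several strings.
// Because it is defined out of class, it also needs a matching declaration
// inside the cppflow::tensor class.
inline tensor::tensor(const std::vector<std::string>& values, const std::vector<int64_t>& shape) {
    const size_t size = values.size();

    // Each element is a TF_TString that may own heap memory, so free every
    // element before releasing the tensor itself.
    auto deallocator = [size](TF_Tensor* tft) {
        TF_TString* tstr = static_cast<TF_TString*>(TF_TensorData(tft));
        for (size_t i = 0; i < size; ++i) {
            TF_TString_Dealloc(&tstr[i]);
        }
        TF_DeleteTensor(tft);
    };

    // Allocate room for `size` TF_TString elements with the requested shape.
    this->tf_tensor = {TF_AllocateTensor(TF_STRING, shape.data(), static_cast<int>(shape.size()), size * sizeof(TF_TString)), deallocator};

    // Initialize each element, then copy the string bytes into it.
    TF_TString* tstr = static_cast<TF_TString*>(TF_TensorData(this->tf_tensor.get()));
    for (size_t i = 0; i < size; ++i) {
        TF_TString_Init(&tstr[i]);
        TF_TString_Copy(&tstr[i], values[i].c_str(), values[i].size());
    }

    // Wrap the tensor in an eager (TFE) handle and check the status.
    this->tfe_handle = {TFE_NewTensorHandle(this->tf_tensor.get(), context::get_status()), TFE_DeleteTensorHandle};
    status_check(context::get_status());
}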
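The constructor allocates `values.size() * sizeof(TF_TString)` bytes and initializes `values.size()` elements, so it implicitly assumes that `shape` describes exactly that many elements; nothing checks this. A caller-side check can be sketched as follows. `num_elements` and `check_string_shape` are hypothetical helpers, not part of cppflow or TensorFlow.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: number of elements described by a concrete shape.
// An empty shape is a scalar and therefore holds exactly one element.
std::int64_t num_elements(const std::vector<std::int64_t>& shape) {
    std::int64_t n = 1;
    for (std::int64_t dim : shape) {
        if (dim < 0) {
            // -1 is only meaningful in a signature, not in an actual tensor.
            throw std::invalid_argument("concrete shapes cannot have negative dimensions");
        }
        n *= dim;
    }
    return n;
}

// Hypothetical helper: throw unless `shape` holds exactly values.size()
// strings, which is what the TF_STRING constructor assumes.
void check_string_shape(const std::vector<std::string>& values,
                        const std::vector<std::int64_t>& shape) {
    if (num_elements(shape) != static_cast<std::int64_t>(values.size())) {
        throw std::invalid_argument("shape does not match the number of strings");
    }
}
```

For the `(-1)` input in the question, a batch of N sentences would use the shape `{N}`: assuming the constructor above has been added to cppflow, that means `check_string_shape(sentences, {static_cast<int64_t>(sentences.size())})` followed by `cppflow::tensor input(sentences, {static_cast<int64_t>(sentences.size())})`.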