microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

How to use FillStringTensor() to fill 4D strings? #10363

Closed. lifang-zhang closed this issue 2 years ago.

lifang-zhang commented 2 years ago

I am trying to use onnxruntime to run inference on the model bidaf-9.onnx, which has an input of type string[c, 1, 1, 16]. I don't know how to prepare the input for it.

// g_ort (const OrtApi*), allocator (OrtAllocator*) and the check_ortstatus()
// helper are set up elsewhere in the program.
OrtStatus *retStatus;
int64_t shape[4] = {10, 1, 1, 16};
OrtValue* input_tensor = NULL;
retStatus = g_ort->CreateTensorAsOrtValue(allocator, shape, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &input_tensor);
if (check_ortstatus(retStatus, g_ort) || !input_tensor) {
    printf("\n CreateTensorAsOrtValue() failed\n\n");
    return;
}

const char context_char[10][16] = {
        "A",
        "quick",
        "fox",
        "jumps",
        "over",
        "the",
        "lazy",
        "dog",
        "."
};
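// NOTE: s_len is 10 here, but the tensor has 10 * 1 * 1 * 16 = 160 elements,
// which triggers the error below. Also, a char[10][16] array is not an array
// of char pointers, so this cast does not produce valid string pointers.
retStatus = g_ort->FillStringTensor(input_tensor, (const char * const*)context_char, 10);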
if (check_ortstatus(retStatus, g_ort)) {
    printf("\n FillStringTensor() failed\n\n");
    return;
}

The code above fails with the error "input array doesn't equal tensor size".

Declaring the array as const char context_char[10][1][1][1][16] = {0}; gave me the same error.

I read the FillStringTensor() documentation carefully. It describes the second parameter, s, as "An array of strings." Does this mean that it only supports 2D strings?
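On the rank question: FillStringTensor is not limited to 2D. Its s parameter is a flat array of const char* pointers, and s_len must equal the tensor's total element count, i.e. the product of all dimensions. Since ONNX Runtime tensors are laid out in row-major (C) order, s[i] should correspond to the i-th element in that order. The plain-C sketch below illustrates the arithmetic with two hypothetical helpers, shape_element_count and flat_index, which are not ORT functions (the C API does expose GetTensorShapeElementCount for the count itself).

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper (not part of the ORT API): total number of elements
   in a tensor, i.e. the product of its dimensions. This is the value that
   FillStringTensor expects as s_len. */
static size_t shape_element_count(const int64_t *shape, size_t rank) {
    size_t count = 1;
    for (size_t i = 0; i < rank; ++i) {
        count *= (size_t)shape[i];
    }
    return count;
}

/* Hypothetical helper: row-major (C order) flat index of the element
   [idx[0]][idx[1]]...[idx[rank - 1]] in a tensor with the given shape. */
static size_t flat_index(const int64_t *shape, const int64_t *idx, size_t rank) {
    size_t offset = 0;
    for (size_t i = 0; i < rank; ++i) {
        assert(idx[i] >= 0 && idx[i] < shape[i]); /* index must be in range */
        offset = offset * (size_t)shape[i] + (size_t)idx[i];
    }
    return offset;
}
```

For the shape {10, 1, 1, 16}, shape_element_count returns 160 rather than the 10 strings passed above, and element [2][0][0][3] maps to flat index 2 * 16 + 3 = 35.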

lifang-zhang commented 2 years ago

Oh, the third parameter of FillStringTensor() should be 160: it is the number of strings, and it must equal the element count of the tensor shape {10, 1, 1, 16} (10 × 1 × 1 × 16 = 160). The following code works:

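// One single-character string per element; each word is padded to 16
// entries with empty strings ("\0").
const char* context_char[10][1][1][16] = {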
    {{{"A", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0"}}},
    {{{"q", "u", "i", "c", "k", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0"}}},
    {{{"b", "r", "o", "w", "n", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"f", "o", "x", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"j", "u", "m", "p", "s", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"o", "v", "e", "r", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"t", "h", "e", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"l", "a", "z", "y", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{"d", "o", "g", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0", "\0"}}},
    {{{".", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0", "\0","\0"}}}
};

// s_len = 10 * 1 * 1 * 16 = 160: one string per tensor element.
retStatus = g_ort->FillStringTensor(input_tensor, (const char * const*)context_char, 160);
if (check_ortstatus(retStatus, g_ort)) {
    printf("\n FillStringTensor() failed\n\n");
    return;
}
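Writing out all 160 literals by hand does not scale to other sentences. Below is a minimal sketch, assuming the model wants one character per element with empty-string padding (the same layout as the working code above), that builds the flat pointer array from a list of words. The names build_char_tensor, flat_chars, MAX_WORDS and MAX_CHARS are hypothetical and chosen only for this illustration.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

#define MAX_WORDS 10  /* first dimension of the {10, 1, 1, 16} shape (assumed) */
#define MAX_CHARS 16  /* last dimension: characters per word (assumed) */

/* Backing storage: one NUL-terminated single-character string per slot. */
static char char_storage[MAX_WORDS * MAX_CHARS][2];

/* Flat, row-major array of string pointers, one per tensor element,
   suitable for FillStringTensor's s parameter. */
static const char *flat_chars[MAX_WORDS * MAX_CHARS];

/* Splits each word into single-character strings and pads every row to
   MAX_CHARS entries with empty strings (same effect as the "\0" literals).
   Element [w][0][0][c] goes to flat index w * MAX_CHARS + c; characters
   beyond MAX_CHARS are dropped. Returns the number of strings written,
   which is the value to pass as s_len. */
static size_t build_char_tensor(const char *const *words, size_t n_words) {
    assert(n_words <= MAX_WORDS);
    for (size_t w = 0; w < n_words; ++w) {
        size_t len = strlen(words[w]);
        for (size_t c = 0; c < MAX_CHARS; ++c) {
            size_t i = w * MAX_CHARS + c;
            if (c < len) {
                char_storage[i][0] = words[w][c];
                char_storage[i][1] = '\0';
                flat_chars[i] = char_storage[i];
            } else {
                flat_chars[i] = ""; /* padding element */
            }
        }
    }
    return n_words * MAX_CHARS;
}
```

With words = {"A", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "."}, build_char_tensor(words, 10) returns 160, and the call becomes g_ort->FillStringTensor(input_tensor, flat_chars, 160) with no cast needed. The static storage keeps the strings valid for the duration of the call; FillStringTensor copies them into the tensor.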