yalue / onnxruntime_go

A Go (golang) library wrapping microsoft/onnxruntime.
MIT License

Model loading question #40

Closed: mvestnik closed this issue 8 months ago

mvestnik commented 8 months ago

Hello! I read the example in the README, and one thing concerned me.

session, err := ort.NewAdvancedSession("path/to/network.onnx",
    []string{"Input 1 Name"}, []string{"Output 1 Name"},
    []ort.ArbitraryTensor{inputTensor}, []ort.ArbitraryTensor{outputTensor}, nil)
if err != nil {
    // Handle the error; session is nil here, so don't defer Destroy() yet.
}
defer session.Destroy()

// Calling Run() will run the network, reading the current contents of the
// input tensors and modifying the contents of the output tensors.
err = session.Run()

Am I right that the model has to be loaded again EVERY time I run a prediction? I looked at the other examples, and in all of them the input tensor is passed to model.Run(). If that is really the case, the library can't be used in production.

Thanks

yalue commented 8 months ago

No, the model does not need to be reloaded every time. It is loaded once, when the session is created, and you can modify the contents of the input tensor between calls to Run():

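package main

import (
    "math/rand"

    ort "github.com/yalue/onnxruntime_go"
)

// Fills the given tensor with random floats.
func changeInput(inputTensor *ort.Tensor[float32]) {
    tensorData := inputTensor.GetData()
    for i := range tensorData {
        tensorData[i] = rand.Float32()
    }
}

func processOutput(outputTensor *ort.Tensor[float32]) {
    outputData := outputTensor.GetData()
    _ = outputData // Do whatever with the output []float32 ...
}

func main() {
    // Placeholder path to the onnxruntime shared library.
    ort.SetSharedLibraryPath("path/to/onnxruntime.so")
    if err := ort.InitializeEnvironment(); err != nil {
        // ...
    }
    defer ort.DestroyEnvironment()

    // Create the tensors once. These shapes are placeholders; use
    // whatever your network expects.
    inputTensor, err := ort.NewEmptyTensor[float32](ort.NewShape(1, 4))
    if err != nil {
        // ...
    }
    defer inputTensor.Destroy()
    outputTensor, err := ort.NewEmptyTensor[float32](ort.NewShape(1, 2))
    if err != nil {
        // ...
    }
    defer outputTensor.Destroy()

    // The model is loaded exactly once, when the session is created.
    session, err := ort.NewAdvancedSession("path/to/network.onnx",
        []string{"Input 1 Name"}, []string{"Output 1 Name"},
        []ort.ArbitraryTensor{inputTensor}, []ort.ArbitraryTensor{outputTensor}, nil)
    if err != nil {
        // ...
    }
    defer session.Destroy()

    // Run the network 100 times, changing the input every time.
    for i := 0; i < 100; i++ {
        changeInput(inputTensor)
        err := session.Run()
        if err != nil {
            // ...
        }
        processOutput(outputTensor)
    }

    // ...
}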

As for passing inputs to the Run() function, that's possible with the DynamicAdvancedSession type. Once again, though, it doesn't require reloading the entire model; it just lets you pass a different set of input and output tensors to each Run() call. Internally, changing the data in existing tensors, which is how the non-dynamic AdvancedSession's Run() works, is marginally more efficient.
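To make the difference between the two session styles concrete, here is a minimal, self-contained Go sketch. Everything in it is hypothetical and is NOT part of onnxruntime_go: `model` and `loadModel` stand in for a loaded network (its "inference" just doubles each input), `boundSession` mimics the AdvancedSession style (buffers bound once at creation, Run() reads whatever they currently hold), and `dynamicSession` mimics the DynamicAdvancedSession style (buffers handed to each Run() call). In both styles, the expensive load happens once per session, never once per Run().

```go
package main

import "fmt"

// loads counts how many times the (hypothetical) model has been loaded.
var loads int

// model is a hypothetical stand-in for a loaded ONNX network; it is NOT part
// of onnxruntime_go. Its "inference" simply doubles every input value.
type model struct{}

// loadModel stands in for the expensive step of reading and initializing a
// network file. The path is ignored; only the load count matters here.
func loadModel(path string) *model {
	loads++
	return &model{}
}

func (m *model) infer(in, out []float32) {
	for i := range in {
		out[i] = 2 * in[i]
	}
}

// boundSession mimics the AdvancedSession style: buffers are bound once at
// creation, and Run() reads whatever the input buffer currently holds.
type boundSession struct {
	m       *model
	in, out []float32
}

func newBoundSession(path string, in, out []float32) *boundSession {
	return &boundSession{m: loadModel(path), in: in, out: out}
}

func (s *boundSession) Run() { s.m.infer(s.in, s.out) }

// dynamicSession mimics the DynamicAdvancedSession style: the model is still
// loaded once, but each Run() call is handed its own input and output buffers.
type dynamicSession struct{ m *model }

func newDynamicSession(path string) *dynamicSession {
	return &dynamicSession{m: loadModel(path)}
}

func (s *dynamicSession) Run(in, out []float32) { s.m.infer(in, out) }

func main() {
	in, out := make([]float32, 1), make([]float32, 1)
	bound := newBoundSession("network.onnx", in, out)
	for i := 0; i < 100; i++ {
		in[0] = float32(i) // overwrite the bound input in place
		bound.Run()
	}
	fmt.Println(out[0], loads) // prints "198 1"

	dyn := newDynamicSession("network.onnx")
	for i := 0; i < 100; i++ {
		dynOut := make([]float32, 1) // a fresh output buffer for every call
		dyn.Run([]float32{float32(i)}, dynOut)
	}
	fmt.Println(loads) // prints "2": one load per session, none per Run
}
```

The dynamic loop allocates new buffers on each iteration, which hints at why reusing bound tensors can be marginally cheaper; the load count stays the same in both styles.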