oramasearch / onnx-go

onnx-go gives the ability to import a pre-trained neural network within Go without being linked to a framework or library.
https://blog.owulveryck.info/2019/04/03/from-a-project-to-a-product-the-state-of-onnx-go.html
MIT License

Question: unsqueeze: axes in not an []int64 #200

Open loeffel-io opened 2 years ago

loeffel-io commented 2 years ago

Hello @owulveryck,

I am trying to run this ONNX model:

model

with this code:

package main

import (
    "fmt"
    "log"
    "os"

    "github.com/owulveryck/onnx-go"
    "github.com/owulveryck/onnx-go/backend/x/gorgonnx"
    "gorgonia.org/tensor"
)

func main() {
    // Create a Gorgonia-based backend and an ONNX model bound to it.
    backend := gorgonnx.NewGraph()
    model := onnx.NewModel(backend)

    var err error
    var b []byte
    var output []tensor.Tensor

    // Read the serialized model from disk (os.ReadFile replaces the deprecated ioutil.ReadFile).
    if b, err = os.ReadFile("model.onnx"); err != nil {
        log.Fatal(err)
    }

    // Decode the ONNX protobuf into the backend graph.
    if err = model.UnmarshalBinary(b); err != nil {
        log.Fatal(err)
    }

    // Input 0: the index of the largest value in a one-hot vector, taken along axis 1.
    var acosGroupTensor tensor.Tensor
    if acosGroupTensor, err = tensor.Argmax(
        tensor.New(tensor.WithShape(1, 5), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 1, 0, 0, 0})),
        1,
    ); err != nil {
        log.Fatal(err)
    }

    if err = model.SetInput(0, acosGroupTensor); err != nil {
        log.Fatal(err)
    }

    // Input 1: a 1x4 float32 tensor.
    var acosRatioTensor = tensor.New(tensor.WithShape(1, 4), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 0, -1, 0}))

    if err = model.SetInput(1, acosRatioTensor); err != nil {
        log.Fatal(err)
    }

    // Input 2: another 1x4 float32 tensor.
    var salesRatioTensor = tensor.New(tensor.WithShape(1, 4), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 0, -1, 0}))

    if err = model.SetInput(2, salesRatioTensor); err != nil {
        log.Fatal(err)
    }

    // Execute the graph.
    if err = backend.Run(); err != nil {
        log.Fatal(err)
    }

    if output, err = model.GetOutputTensors(); err != nil {
        log.Fatal(err)
    }
    // Write the first output to stdout.
    fmt.Println(output[0])
}

It looks like I am missing something small; I would love some help ❤️ Thank you!

loeffel-io commented 2 years ago

ref: https://github.com/owulveryck/onnx-go/blob/a1116f3e59d41e088d88670cd1e0683709882806/backend/x/gorgonnx/unsqueeze.go#L57

owulveryck commented 2 years ago

My guess is that axes is an int64 rather than an []int64, which causes the error.
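To illustrate that guess, here is a minimal, self-contained sketch (not onnx-go code) of how Go's comma-ok type assertion behaves when a value stored in an interface holds a scalar int64 but the code asserts []int64. The map is a hypothetical stand-in for a decoded operator attribute map; the helper name hasInt64Slice is illustrative only.

```go
package main

import "fmt"

// hasInt64Slice reports whether the value stored under key has dynamic type []int64.
// attrs is a hypothetical stand-in for an operator's decoded attribute map.
func hasInt64Slice(attrs map[string]interface{}, key string) bool {
	_, ok := attrs[key].([]int64)
	return ok
}

func main() {
	scalar := map[string]interface{}{"axes": int64(0)}
	list := map[string]interface{}{"axes": []int64{0}}

	fmt.Println(hasInt64Slice(scalar, "axes")) // false: the dynamic type is int64
	fmt.Println(hasInt64Slice(list, "axes"))   // true
}
```

The assertion does not convert between types: a scalar int64 never satisfies .([]int64), so ok is false and the error branch runs.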

I don't know why, but you can try to change the code in unsqueeze.go like this:


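    // a is the unsqueeze operator's receiver; use a distinct name for the
    // scalar axis so it does not shadow a.
    axis, ok := o.Attributes["axes"].(int64)
    if !ok {
        return errors.New("unsqueeze: axes is not an int64")
    }
    a.Axes = []int64{axis}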
loeffel-io commented 2 years ago

Thanks @owulveryck, I created a similar workaround, but then I got stuck because the Cast operator is not implemented.

I moved to https://github.com/triton-inference-server/server

I'll leave this open to flag that the Unsqueeze operator seems unstable; feel free to close it if not.