oramasearch / onnx-go

onnx-go gives the ability to import a pre-trained neural network within Go without being linked to a framework or library.
https://blog.owulveryck.info/2019/04/03/from-a-project-to-a-product-the-state-of-onnx-go.html
MIT License
704 stars, 72 forks

Is Conv1D supported? #210

Open. carbocation opened this issue 1 year ago

carbocation commented 1 year ago

I have an ONNX model that expects input of shape (B, C, S), where S is the signal length, e.g. (1, 1, 350). When I run it with code based on the demo (shown below), I get the error shown beneath the code. Is Conv1D not supported, or should I change how I set up the input? It is currently created with tensor.New(tensor.WithShape(1, 1, signalLength), tensor.Of(tensor.Float32)).
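The demo feeds a zero-filled tensor. To pass a real signal instead, the values for a (B, C, S) input have to be laid out in a single flat slice. Gorgonia's dense tensors are row-major by default, so element (b, c, s) sits at offset (b*C + c)*S + s. The sketch below uses only the standard library; flatIndex and packSignals are hypothetical helpers, and the ramp signal is a placeholder.

```go
package main

import "fmt"

// flatIndex returns the row-major offset of element (b, c, s)
// in a tensor of shape (B, C, S). Hypothetical helper.
func flatIndex(b, c, s, C, S int) int {
	return (b*C+c)*S + s
}

// packSignals copies signals[b][c] (each of length S) into one flat,
// row-major []float32 of length B*C*S, checking that shapes are consistent.
// Hypothetical helper.
func packSignals(signals [][][]float32, S int) ([]float32, error) {
	B := len(signals)
	if B == 0 {
		return nil, fmt.Errorf("no signals")
	}
	C := len(signals[0])
	out := make([]float32, B*C*S)
	for b := 0; b < B; b++ {
		if len(signals[b]) != C {
			return nil, fmt.Errorf("batch %d: got %d channels, want %d", b, len(signals[b]), C)
		}
		for c := 0; c < C; c++ {
			if len(signals[b][c]) != S {
				return nil, fmt.Errorf("batch %d channel %d: got length %d, want %d",
					b, c, len(signals[b][c]), S)
			}
			for s := 0; s < S; s++ {
				out[flatIndex(b, c, s, C, S)] = signals[b][c][s]
			}
		}
	}
	return out, nil
}

func main() {
	const signalLength = 350
	// Placeholder signal: a simple ramp, one batch item, one channel.
	sig := make([]float32, signalLength)
	for i := range sig {
		sig[i] = float32(i) / signalLength
	}
	backing, err := packSignals([][][]float32{{sig}}, signalLength)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(backing)) // 350
	// Assumption: with gorgonia.org/tensor this slice could then be used as
	// tensor.New(tensor.WithShape(1, 1, signalLength), tensor.WithBacking(backing)).
}
```

Repacking the data only changes what values reach the model; it does not change which convolution the model's graph asks the backend to perform.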

package main

import (
    "fmt"
    "log"
    "os"

    "github.com/owulveryck/onnx-go"
    "github.com/owulveryck/onnx-go/backend/x/gorgonnx"
    "gorgonia.org/tensor"
)

var signalLength = 350

func main() {
    // Create a backend receiver
    backend := gorgonnx.NewGraph()
    // Create a model and set the execution backend
    model := onnx.NewModel(backend)

    // Read the ONNX model from disk
    b, err := os.ReadFile("../model.pt.onnx")
    if err != nil {
        log.Fatal(err)
    }
    // Decode it into the model
    err = model.UnmarshalBinary(b)
    if err != nil {
        log.Fatal(err)
    }

    // Set the first input; the index depends on the model.
    // Zero-filled tensor of shape (batch, channels, signal length).
    input := tensor.New(tensor.WithShape(1, 1, signalLength), tensor.Of(tensor.Float32))
    if err := model.SetInput(0, input); err != nil {
        log.Fatal(err)
    }
    err = backend.Run()
    if err != nil {
        log.Fatal(err)
    }
    // Retrieve the output tensors
    output, err := model.GetOutputTensors()
    if err != nil {
        log.Fatal(err)
    }
    // Write the first output to stdout
    fmt.Println(output[0])
}

Error:

go run *.go
2023/08/27 14:22:30 conv: kernel shape is supposed to have a dim of 2
exit status 1
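The message appears to refer to the convolution kernel, whose shape comes from the model's Conv node rather than from the input tensor, so changing the input shape alone is unlikely to help. One common workaround, assuming the backend handles 2D convolution, is to express the 1D convolution as a 2D one: a (B, C, S) input becomes (B, C, 1, S) and a length-k kernel becomes a (1, k) kernel. The model itself would have to be re-exported in that form (for example by using a 2D convolution with a height-1 kernel in the source framework); that step is not shown here. The stdlib-only sketch below only checks numerically that the two formulations agree for a single channel, stride 1 and no padding; conv1d and conv2d are hypothetical helpers, not onnx-go functions.

```go
package main

import "fmt"

// conv1d computes a "valid" 1D cross-correlation (no kernel flip, as in
// ONNX Conv) for a single channel, stride 1, no padding. Hypothetical helper.
func conv1d(x, k []float32) []float32 {
	n := len(x) - len(k) + 1
	if len(k) == 0 || n <= 0 {
		return nil
	}
	out := make([]float32, n)
	for i := 0; i < n; i++ {
		var sum float32
		for j, w := range k {
			sum += x[i+j] * w
		}
		out[i] = sum
	}
	return out
}

// conv2d computes a "valid" 2D cross-correlation for a single channel,
// stride 1, no padding. Hypothetical helper.
func conv2d(x, k [][]float32) [][]float32 {
	if len(x) == 0 || len(k) == 0 || len(k[0]) == 0 {
		return nil
	}
	oh := len(x) - len(k) + 1
	ow := len(x[0]) - len(k[0]) + 1
	if oh <= 0 || ow <= 0 {
		return nil
	}
	out := make([][]float32, oh)
	for i := range out {
		out[i] = make([]float32, ow)
		for j := range out[i] {
			var sum float32
			for a := range k {
				for b := range k[a] {
					sum += x[i+a][j+b] * k[a][b]
				}
			}
			out[i][j] = sum
		}
	}
	return out
}

func main() {
	x := []float32{1, 2, 3, 4, 5}
	k := []float32{1, 0, -1}
	y1 := conv1d(x, k)
	// Same data viewed with an extra height-1 axis: input (1, S), kernel (1, k).
	y2 := conv2d([][]float32{x}, [][]float32{k})
	fmt.Println(y1)    // [-2 -2 -2]
	fmt.Println(y2[0]) // [-2 -2 -2]
}
```

With several channels and a batch axis the same argument applies per channel, since the extra height-1 axis never interacts with the kernel beyond its single row.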