ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

CplxConv1d can be exported to ONNX but cannot be inferred by ONNXRUNTIME #13

Closed: pfeatherstone closed this issue 3 years ago

pfeatherstone commented 3 years ago

The following example shows that cplxmodule.nn.CplxConv1d can be exported to ONNX, but the exported model fails when it is run with ONNX Runtime.

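    import numpy as np
    import torch
    import torch.nn as nn
    import cplxmodule
    import cplxmodule.nn
    import onnxruntime

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # 1 input channel, 1 output channel, kernel size 3; padding=1 keeps the length
            self.net = cplxmodule.nn.CplxConv1d(1, 1, 3, padding=1)

        def forward(self, x):
            # x has shape (..., 2): the last axis holds the real and imaginary parts
            x = cplxmodule.Cplx(x[..., 0], x[..., 1])
            x = self.net(x)
            # pack the complex result back into a real tensor of shape (..., 2)
            return torch.stack([x.real, x.imag], dim=-1)

    model = Net().eval()

    input = torch.randn(1, 1, 1024, 2)  # (batch, channels, length, re/im)
    out = model(input)

    torch.onnx.export(model,
                      (input,),
                      "file.onnx",
                      opset_version=12,
                      input_names=['in'],
                      output_names=['out'])

    print("exported")

    # run the exported graph and compare against the PyTorch output
    ort_session = onnxruntime.InferenceSession("file.onnx")
    ort_inputs = {ort_session.get_inputs()[0].name: input.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)
    assert len(ort_outs) == 1, "bad number of outputs"
    np.testing.assert_allclose(out.detach().cpu().numpy(), ort_outs[0], rtol=1e-04, atol=1e-05)
    print("done")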
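For an independent check of the numbers, a complex 1-D convolution like the one CplxConv1d computes can be written with four real cross-correlations, since (w_re + i w_im)(x_re + i x_im) = (w_re x_re - w_im x_im) + i (w_re x_im + w_im x_re). The NumPy sketch below assumes that decomposition, a single channel, stride 1 and no bias; cross_correlate_1d and complex_conv1d are hypothetical helpers, not part of cplxmodule.

```python
import numpy as np


def cross_correlate_1d(x, w, padding=0):
    """Real 1-D cross-correlation with zero padding, stride 1, no dilation.

    Like torch's conv1d, the kernel is not flipped.
    """
    x = np.pad(np.asarray(x, dtype=float), padding)
    w = np.asarray(w, dtype=float)
    n_out = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n_out)])


def complex_conv1d(x_re, x_im, w_re, w_im, padding=0):
    """Hypothetical helper: complex cross-correlation from four real ones."""
    def conv(x, w):
        return cross_correlate_1d(x, w, padding)

    out_re = conv(x_re, w_re) - conv(x_im, w_im)
    out_im = conv(x_im, w_re) + conv(x_re, w_im)
    return out_re, out_im


# Same kernel size and padding as the layer in the issue.
rng = np.random.default_rng(0)
x_re, x_im = rng.standard_normal(16), rng.standard_normal(16)
w_re, w_im = rng.standard_normal(3), rng.standard_normal(3)
out_re, out_im = complex_conv1d(x_re, x_im, w_re, w_im, padding=1)
print(out_re.shape, out_im.shape)  # (16,) (16,)
```

If the layer has a bias, its real and imaginary parts would have to be added to out_re and out_im before comparing with the PyTorch or ONNX Runtime output.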
pfeatherstone commented 3 years ago

I get the following error:

    [E:onnxruntime:, sequential_executor.cc:318 Execute] Non-zero status code returned while running Slice node. Name:'Slice_11' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/slice.cc:221 static void onnxruntime::SliceBase::FillVectorsFromInput(const onnxruntime::Tensor&, const onnxruntime::Tensor&, const onnxruntime::Tensor*, const onnxruntime::Tensor*, std::vector<long int>&, std::vector<long int>&, std::vector<long int>&, std::vector<long int>&) ends_tensor.Shape().NumDimensions() == 1 was false. Ends must be a 1-D array
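The failing check concerns the ends input of an ONNX Slice node: from opset 10 onward, Slice takes starts, ends and the optional axes and steps as 1-D tensors with one entry per sliced axis, so a scalar (0-D) ends tensor is rejected. The NumPy sketch below mimics those semantics for positive steps only; onnx_style_slice is a hypothetical name and this is not ONNX Runtime's implementation.

```python
import numpy as np


def onnx_style_slice(data, starts, ends, axes=None, steps=None):
    """Hypothetical NumPy mimic of ONNX Slice (opset >= 10), positive steps only."""
    starts = np.asarray(starts)
    ends = np.asarray(ends)
    # The condition that failed in the log: starts and ends must be 1-D.
    if starts.ndim != 1 or ends.ndim != 1:
        raise ValueError("Starts and ends must be 1-D arrays")
    if axes is None:
        axes = range(len(starts))
    if steps is None:
        steps = [1] * len(starts)
    index = [slice(None)] * data.ndim
    for start, end, axis, step in zip(starts, ends, axes, steps):
        if step <= 0:
            raise NotImplementedError("this sketch only handles positive steps")
        # Python slices clamp out-of-range bounds, as Slice does for positive steps.
        index[axis] = slice(int(start), int(end), int(step))
    return data[tuple(index)]


x = np.zeros((1, 1, 1024, 2))
print(onnx_style_slice(x, [0], [1], axes=[-1]).shape)  # (1, 1, 1024, 1)
try:
    onnx_style_slice(x, np.array(0), np.array(1), axes=[-1])  # 0-D starts/ends
except ValueError as err:
    print(err)  # Starts and ends must be 1-D arrays
```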
pfeatherstone commented 3 years ago

I doubt this is an onnxruntime bug, but then again, the debug output gives no useful hints.

ivannz commented 3 years ago

Thank you for the catch, @pfeatherstone. The issue has been fixed in PR #14.