microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Resize with mode linear always produces 0.5 on GPU regardless of the input #12091

Open lazycal opened 2 years ago

lazycal commented 2 years ago

Describe the bug: In a model with a Linear layer followed by a trilinear resize, like the graph below, the GPU result is always 0.5 for any input. This differs from both the CPU result and the PyTorch result, which agree with each other.

[image]

The corresponding torch code:

import torch

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    @torch.no_grad()
    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.interpolate(
            x, size=[511, 1, 1], mode='trilinear')
        return x

This may be related to https://github.com/microsoft/onnxruntime/issues/12019 (cc the participants there, @diyessi @hariharans29). However, I don't see the problem with 4D tensors (i.e., bilinear), and my problem is on GPU while that one seems to be on CPU. Nearest mode also appears to work for me. Removing the Linear node makes the problem disappear as well.
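For reference, this is what the correct answer should look like. With half_pixel coordinates, ONNX Resize maps output index i to the source coordinate (i + 0.5) / scale - 0.5. When an input axis has length 1, every such coordinate falls within (-0.5, 0.5) and clamps to 0, so upsampling that axis simply repeats the single value. A constant 0.5 would only be correct if the Linear layer happened to output exactly 0.5. Below is a minimal numpy sketch of 1-D linear resize. It assumes out-of-range coordinates are clamped to the edge, as PyTorch does with align_corners=False. The function name is hypothetical and is not an ONNX Runtime API.

```python
import numpy as np


def resize_linear_1d_half_pixel(x: np.ndarray, out_len: int) -> np.ndarray:
    """1-D linear resize using half_pixel coordinates and edge clamping."""
    in_len = x.shape[0]
    scale = out_len / in_len
    # Map each output index back to a (possibly fractional) source coordinate.
    src = (np.arange(out_len, dtype=np.float64) + 0.5) / scale - 0.5
    # Clamp out-of-range coordinates to the edge of the input.
    src = np.clip(src, 0.0, in_len - 1)
    lo = np.floor(src).astype(np.int64)
    hi = np.minimum(lo + 1, in_len - 1)
    frac = src - lo
    # Blend the two neighbouring input samples.
    return (x[lo] * (1.0 - frac) + x[hi] * frac).astype(x.dtype)


# A single-element axis upsampled to 511 just repeats the value.
print(resize_linear_1d_half_pixel(np.array([3.7], dtype=np.float32), 511)[:3])
# A two-element axis upsampled to 4 interpolates between the endpoints.
print(resize_linear_1d_half_pixel(np.array([0.0, 1.0], dtype=np.float32), 4))
```

Applied axis by axis, the same rule explains why the [1, 1, 1, 1, 1] input in this report should come back as a constant tensor equal to the Linear output.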

Urgency: None

System information

To Reproduce: Run this code:

import torch
from onnx import checker
import onnxruntime as ort
import numpy as np

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    @torch.no_grad()
    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.interpolate(
            x, size=[511, 1, 1], mode='trilinear')
        return x

x = torch.randn(1, 1, 1, 1, 1).to(torch.float32) * 100
model = Model()
model.eval()
torch.onnx.export(model, (x,), "output.onnx",
                  input_names=["x"], output_names=["y"], opset_version=14)
checker.check_model("output.onnx", full_check=True)
print('model checked')
b_tch = model(x).numpy()  # PyTorch reference output

sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])
b_ort_cpu = sess.run(["y"], {"x": x.numpy()})[0]
np.testing.assert_allclose(
    b_ort_cpu, b_tch, err_msg="ort_cpu vs torch", atol=1e-2, rtol=1e-2)
print('-------------> ort_cpu is consistent to pytorch')

sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CUDAExecutionProvider",
                            "CPUExecutionProvider"])
b_ort_gpu = sess.run(["y"], {"x": x.numpy()})[0]
np.testing.assert_allclose(
    b_ort_gpu, b_tch, err_msg="ort_gpu vs torch", atol=1e-2, rtol=1e-2)
print('pass')

ONNX model: model.onnx.zip

Expected behavior: The GPU result should be consistent with the CPU and PyTorch results.

Screenshots: [image]

Additional context: None

UPDATE: Below is the same model written in pure ONNX, without PyTorch, since PyTorch may have bugs in Resize-related nodes. The issue still remains.

import numpy as np
import onnxruntime as ort
import onnx
from onnx import helper, checker
from onnx import TensorProto

ash = [1, 1, 1, 1, 1]
bsh = [1, 1, 511, 1, 3]

a = helper.make_tensor_value_info('a', TensorProto.FLOAT, ash)
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, bsh)
sizes = np.array(bsh, dtype=np.int64)
w = np.random.rand(1, 1).astype(np.float32)
bias = np.random.rand(1).astype(np.float32)

node_matmul = onnx.helper.make_node(
    'MatMul',
    inputs=['a', 'w'],
    outputs=['t0'],
)

node_add = onnx.helper.make_node(
    'Add',
    inputs=['t0', 'bias'],
    outputs=['t'],
)

node = onnx.helper.make_node(
    'Resize',
    inputs=['t', '', '', 'sizes'],
    outputs=['b'],
    mode='linear',
    coordinate_transformation_mode='half_pixel'
)

graph_def = helper.make_graph(
    [node_matmul, node_add, node],
    'test-model',
    [a],
    outputs=[b],
    initializer=[
        helper.make_tensor('sizes', TensorProto.INT64, [5], sizes),
        helper.make_tensor('w', TensorProto.FLOAT, [1, 1], w),
        helper.make_tensor('bias', TensorProto.FLOAT, [1], bias)
    ]
)
model_def = helper.make_model(
    graph_def, producer_name='onnx-example')

print('The model is:\n{}'.format(model_def))
checker.check_model(model_def, full_check=True)
print('The model is checked!')

onnx.save(model_def, 'output.onnx')

x = np.random.randn(1, 1, 1, 1, 1).astype(np.float32)
sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CPUExecutionProvider"])
b_ort_cpu = sess.run(["b"], {"a": x})[0]

sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession("output.onnx", sess_opt, providers=[
                            "CUDAExecutionProvider",
                            "CPUExecutionProvider"])
b_ort_gpu = sess.run(["b"], {"a": x})[0]
np.testing.assert_allclose(
    b_ort_gpu, b_ort_cpu, err_msg="ort_gpu vs ort_cpu", atol=1e-2, rtol=1e-2)
print('pass')
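Every dimension of the Resize input in this graph has length 1: MatMul of the [1, 1, 1, 1, 1] input with the [1, 1] weight, plus the [1] bias. This means the correct output can be computed without any execution provider. It is the single value x·w + bias repeated over the requested sizes, assuming out-of-range half_pixel coordinates are clamped to the edge. That assumption is consistent with this report, where the CPU result matched PyTorch. The resulting oracle does not depend on the CPU provider being right. Below is a minimal numpy sketch. The function name is hypothetical.

```python
import numpy as np


def expected_resize_of_scalar(x, w, bias, sizes):
    """Reference output for the MatMul -> Add -> Resize graph above.

    Only valid when every dimension of the Resize input is 1: linear
    interpolation (half_pixel, edge-clamped) then just replicates the value.
    """
    t = np.matmul(x, w) + bias  # shape (1, 1, 1, 1, 1) for the inputs above
    if any(d != 1 for d in t.shape):
        raise ValueError(f"sketch only covers size-1 inputs, got shape {t.shape}")
    out_shape = tuple(int(s) for s in sizes)
    # broadcast_to returns a read-only view; astype makes a writable copy.
    return np.broadcast_to(t, out_shape).astype(np.float32)
```

In the script above, `np.testing.assert_allclose(b_ort_gpu, expected_resize_of_scalar(x, w, bias, sizes), atol=1e-2, rtol=1e-2)` would check the GPU output directly against this reference instead of against the CPU provider.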

UPDATE 2: Nearest mode turns out to be problematic too. See this model: [model.onnx.zip]

Use this code to reproduce:

UPDATE 3: After looking at it more closely, this seems to be a separate issue, now reported in https://github.com/microsoft/onnxruntime/issues/12098

ytaous commented 2 years ago

@hariharans29 @yihonglyu - any comment/feedback is more than welcome.

shgoyal33 commented 1 year ago

@lazycal I tried recreating the ONNX graph using your code, but when I look at it in Netron I don't get the output from the Resize block. I would really appreciate your help here.

lazycal commented 1 year ago

@shgoyal33 Do you mean that my pure ONNX code did not give you the 2-branch structure shown in the image below? That structure is an artifact of PyTorch's conversion and is not related to this issue. On my end, my ONNX code produces the graph below, with the same error: [image] [image]

The onnx model: output.onnx 2.zip

shgoyal33 commented 1 year ago

@lazycal Hi, the problem I'm facing is actually slightly different, but your code and the graph helped a lot. When I run your code and open the ONNX file in Netron, this is what I get: [image]

Judging from your code and the graph image you shared, you get the final shape on the Resize layer. When I open its properties, however, I see [unknown_dim_0,unknown_dim_1,unknown_dim_2,unknown_dim_3]. In your ONNX file, the output of the Resize block is [1,1,511,1,1], and the graph looks like this: [image]

I want the dimensions to be displayed between Resize and the output y. I ran the code you gave, so what else did you change to get the output dimensions of the Resize block?

lazycal commented 1 year ago

@shgoyal33 No, I did not change anything else. I guess it's due to a PyTorch version difference. I am using "1.13.0a0+git018d071". I forget how I installed it; it may have been built from source. In any case, the other code snippet is pure ONNX, does not use PyTorch, and reproduces the same issue. Maybe you could use that instead.

frankmanbb commented 8 months ago

Any update on this issue? I encountered the same problem when adding a resize layer in PyTorch:

    image = F.interpolate(image.unsqueeze(0),
                          size=(self.resize_width_height[1], self.resize_width_height[0]),
                          mode='bilinear',
                          recompute_scale_factor=False,
                          align_corners=False)

The resulting ONNX model misbehaves at inference time. CPU inference is fine, but GPU inference outputs 0.5 everywhere. The exception is when the input size equals the target size; in that case I guess the code just copies the input and avoids the all-0.5 output.

When I change to mode='bicubic', everything works well.

I suspect there is a bug in the GPU implementation of the Resize layer.

frankmanbb commented 8 months ago

Update: changing the image type from uint8 to float32 by calling image = image.float() also solves the issue, so apparently the uint8 Resize path on CUDA has a bug.
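The same float32 workaround can also be applied on the caller side, for example if the model is exported with a float32 input instead of casting inside forward. In that case, convert the uint8 image before session.run. If a uint8 image is needed afterwards, round and clip the resized output. Below is a minimal numpy sketch that assumes this caller-side placement. Both function names are hypothetical.

```python
import numpy as np


def to_float32_input(image: np.ndarray) -> np.ndarray:
    """Cast a uint8 image to float32 so Resize never sees uint8 data."""
    if image.dtype != np.uint8:
        raise TypeError(f"expected a uint8 image, got {image.dtype}")
    return image.astype(np.float32)


def to_uint8_image(resized: np.ndarray) -> np.ndarray:
    """Round a float32 resize result to the nearest integer and clip to [0, 255]."""
    return np.clip(np.rint(resized), 0, 255).astype(np.uint8)
```

Clipping matters because linear or cubic interpolation of float data can overshoot slightly. Rounding before the cast avoids the systematic downward bias that plain truncation would introduce.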