apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
Apache License 2.0

[Bug] [Relay] [Torch] [ONNX] Robustness of `Cast` operator accepting `NaN` values #17081

Open shaoyuyoung opened 4 weeks ago



Here is a single op: `Cast`.

In TVM, casting a `NaN` input to bool outputs `False`.

However, PyTorch outputs `True` for the same input.

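In PyTorch and ONNX, `Cast` maps zero to `False` and every other value to `True`. The ONNX specification states that, when casting from floating point to bool, +/- 0.0 becomes `False` and all other values become `True`; since `NaN` is not zero, it should become `True`. The evidence is here: https://onnx.ai/onnx/operators/onnx__Cast.html#l-onnx-doc-cast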
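The float-to-bool rule quoted from the ONNX specification can be written as a small element-wise reference function. This is a sketch for comparison purposes; `onnx_cast_float_to_bool` is a hypothetical helper name, not part of ONNX, PyTorch, or TVM.

```python
import math


def onnx_cast_float_to_bool(values):
    """Element-wise float -> bool following the ONNX Cast rule:
    +/- 0.0 becomes False; every other value (including NaN and +/-inf) becomes True."""
    # `v != 0.0` is True for NaN, because NaN compares unequal to everything,
    # and False for -0.0, because -0.0 == 0.0 under IEEE 754.
    return [v != 0.0 for v in values]


inputs = [float("nan"), 0.0, -0.0, 1.5, -math.inf]
print(onnx_cast_float_to_bool(inputs))  # [True, False, False, True, True]
```

Under this rule the single `NaN` element in the repro should produce `True`, which is what PyTorch returns.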

I am not sure how the `Cast` op is defined in TVM. But if its semantics differ from those of other frameworks/compilers (e.g., PyTorch and ONNX), TVM's results will diverge from theirs in more complex scenarios, e.g., a model in which the cast output feeds further ops.
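One plausible explanation for the `False` result, offered as an unverified assumption rather than a diagnosis of TVM's code, is that the float-to-bool cast is lowered to an *ordered* not-equal comparison against 0.0 (like LLVM's `fcmp one`), which is false whenever an operand is NaN. The ONNX/PyTorch result instead corresponds to an *unordered* not-equal (like LLVM's `fcmp une`), which is true whenever an operand is NaN. The pure-Python sketch below emulates both predicates; the function names are hypothetical.

```python
import math


def cast_unordered_ne(x):
    # Emulates an unordered not-equal (LLVM `fcmp une`) against 0.0:
    # True if x is NaN OR x != 0.0. Matches the ONNX/PyTorch float -> bool rule.
    return math.isnan(x) or x != 0.0


def cast_ordered_ne(x):
    # Emulates an ordered not-equal (LLVM `fcmp one`) against 0.0:
    # True only if x is not NaN AND x != 0.0. Maps NaN to False.
    return not math.isnan(x) and x != 0.0


nan = float("nan")
print(cast_unordered_ne(nan), cast_ordered_ne(nan))  # True False
print(cast_unordered_ne(2.0), cast_ordered_ne(2.0))  # True True
print(cast_unordered_ne(0.0), cast_ordered_ne(0.0))  # False False
```

The two predicates agree on every input except NaN, which would explain why the mismatch only shows up for this edge case.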

Code to repro

```python
import torch
import torch.nn as nn
import tvm
from tvm import relay
import numpy as np
import onnx
import numpy.testing as npt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_tensor):
        # float32 -> bool; exported to ONNX as a single Cast node
        cast_output = input_tensor.to(torch.bool)
        return cast_output


model = Model()
input_tensor = torch.tensor([float('nan')])

# PyTorch reference result
torch_output = model(input_tensor).numpy()

# Export the model so that "test.onnx" exists and its input is named "input"
torch.onnx.export(model, (input_tensor,), "test.onnx", input_names=["input"])
onnx_model = onnx.load("test.onnx")

target = "llvm"
shape_dict = {"input": input_tensor.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=4):
    # evaluate() returns a callable that runs the compiled graph
    executor = relay.build_module.create_executor(
        "graph", mod, dev, target, params
    ).evaluate()

inputs = {"input": tvm.nd.array(input_tensor.numpy())}
tvm_output = executor(**inputs).numpy()

npt.assert_allclose(torch_output, tvm_output, rtol=1e-5, atol=1e-8)
```

Error log

```
Not equal to tolerance rtol=1e-05, atol=1e-08

Mismatched elements: 1 / 1 (100%)
 x: array([ True])
 y: array([False])
```

Environment & Version

OS: Ubuntu 20
TVM commit: d1ac1c0202b3d8cb2af268ce79c2ac710554152b

cc @KJlaccHoeUM9l @shingjan @yelite