I am unsure how the `Cast` op is defined in TVM. But if its definition differs from other frameworks/compilers (e.g., PyTorch and ONNX), the final results will be inconsistent with them in more complex scenarios (i.e., models containing more ops).
Code to repro
import pickle
import torch
import torch.nn as nn
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np
import onnx
import numpy.testing as npt
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_tensor):
        # Cast the float input to bool; this is the single op under test.
        cast_output = input_tensor.to(torch.bool)
        return cast_output
model = Model()
input_tensor = torch.tensor([float('nan')])
torch_output = model(input_tensor).numpy()
torch.onnx.export(
    model,
    input_tensor,
    "test.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=True,
)
onnx_model = onnx.load("test.onnx")
target = "llvm"
shape_dict = {"input": input_tensor.shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=4):
    executor = relay.build_module.create_executor(
        "graph", mod, dev, target, params
    ).evaluate()
inputs = {"input": tvm.nd.array(input_tensor.numpy())}
tvm_output = executor(**inputs).numpy()
npt.assert_allclose(torch_output, tvm_output, rtol=1e-5, atol=1e-8)
Error log
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-08
Mismatched elements: 1 / 1 (100%)
x: array([ True])
y: array([False])
Description
Here is a single op:
![image](https://github.com/apache/tvm/assets/100203773/ff67783a-4158-4545-946c-a77da40eb245)

In TVM, when `Cast` receives a NaN value, it outputs False. However, in PyTorch, it outputs True.

In PyTorch and ONNX, `Cast` to bool converts a zero value to False and every other value to True. NaN is nonzero, so it becomes True.
![image](https://github.com/apache/tvm/assets/100203773/cc7a71a6-c707-48b6-a3e8-8f86ac2c06de)

The evidence is here: https://onnx.ai/onnx/operators/onnx__Cast.html#l-onnx-doc-cast
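
The expected float-to-bool behaviour can be sketched in plain Python. This is only a minimal illustration, assuming the rule "zero maps to False, everything else maps to True"; `cast_to_bool` is a hypothetical helper written for this issue, not an API from TVM, PyTorch, or ONNX, and it says nothing about how TVM implements `Cast` internally.

```python
import math


def cast_to_bool(x: float) -> bool:
    """Hypothetical reference cast: only zero maps to False."""
    # NaN compares unequal to every value, including 0.0,
    # so under this rule NaN maps to True.
    return x != 0.0


samples = [0.0, -0.0, 1.5, -2.0, math.inf, math.nan]
results = [cast_to_bool(x) for x in samples]

for x, r in zip(samples, results):
    print(f"{x!r} -> {r}")
```

Under this rule the NaN input from the repro maps to True, which matches the PyTorch output in the error log rather than the TVM output.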
Environment & Version
Ubuntu 20; TVM at commit d1ac1c0202b3d8cb2af268ce79c2ac710554152b
cc @KJlaccHoeUM9l @shingjan @yelite