Azyka opened this issue 10 months ago.
The issue appears to be fixed in the latest version, 0.16.dev. Personally, I think the problem lies in how argmax handles special cases (NaN values). In this example, the output is:
    {'v4_0': array([8, 8])}
    {'v14_0': array([[ 1.,  1.],
           [nan,  1.],
           [nan, nan],
           [nan, nan],
           [nan, nan],
           [nan, nan],
           [nan, nan],
           [nan, nan],
           [nan, nan]], dtype=float16), 'v24_0': array([8, 8])}
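Node v4_0 outputs the indices of the maximum values of v14_0 along its first axis. However, NaN cannot be meaningfully compared with a valid number, so which index counts as the "maximum" in a column containing NaN is not well defined.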
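Because every ordered comparison involving NaN evaluates to False, the index an argmax reduction returns for these columns depends entirely on how its comparison is written. The following is a minimal, hypothetical sketch in plain Python (the function names are illustrative, not TVM or PyTorch APIs). It runs three plausible comparison rules on the two columns of v14_0 shown above. One rule happens to reproduce `[8, 8]`, but this is not a claim about how TVM's argmax kernel is actually implemented.

```python
import math

NAN = float("nan")

# The two columns of v14_0 from the output above (9x2, float16 in the original).
COL0 = [1.0] + [NAN] * 8
COL1 = [1.0, 1.0] + [NAN] * 7


def argmax_strict(xs):
    """Replace the running best only if the candidate is strictly greater.
    Every comparison with NaN is False, so NaNs are silently skipped."""
    best = 0
    for i in range(1, len(xs)):
        if xs[i] > xs[best]:
            best = i
    return best


def argmax_not_less(xs):
    """Replace the running best unless the candidate is strictly smaller.
    `nan < x` is False, so every NaN (and every tie) wins; the scan ends on the last index."""
    best = 0
    for i in range(1, len(xs)):
        if not (xs[i] < xs[best]):
            best = i
    return best


def argmax_nan_first(xs):
    """Treat NaN as the maximum and return the index of the first NaN."""
    for i, x in enumerate(xs):
        if math.isnan(x):
            return i
    return argmax_strict(xs)


for fn in (argmax_strict, argmax_not_less, argmax_nan_first):
    print(fn.__name__, [fn(COL0), fn(COL1)])
# argmax_strict [0, 0]
# argmax_not_less [8, 8]
# argmax_nan_first [1, 2]
```

All three rules agree on NaN-free input and differ only once NaN appears. This is why two graphs that compile the same argmax differently could legitimately disagree on this input.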
Expected behavior
When an extra torch.Tensor.transpose node is added as an output of this graph:
New:
The output of torch.argmax is expected to be the same for the same input in these two graphs.
Actual behavior
The outputs of the two graphs differ after tvm_opt_4.
Environment
Steps to reproduce
Sample code:
Triage