microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

OneHot should treat negative axis as range from last dimension #1313

Closed. fdwr closed this issue 5 years ago.

fdwr commented 5 years ago

Describe the bug

The Lotus CPU implementation of OneHot rejects negative axes other than -1 (it rejects -2, -3, ...), which is overly strict and inconsistent with the spec. It should treat a negative axis the way other ONNX operators (Gather, Pad, Slice) do, as a position counted back from the last dimension. Although the spec doesn't state this explicitly, it is implied by the wording, and ONNX shape inference accepts any negative axis in the valid range.

The validation rejects valid values here: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/tensor/onehot.h

if (tmp_axis < -1) {
    ORT_THROW("Value of axis is < -1");
}

Contrast with ONNX shape inference: https://github.com/onnx/onnx/blob/master/onnx/defs/tensor/defs.cc

    if (axis < -out_rank || axis >= out_rank) {
        fail_shape_inference(
            "'axis' must be in [-rank(indices)-1, rank(indices)]");
    }
    if (axis < 0) {
        axis += out_rank;
    }
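One way the CPU kernel could follow the same rule is sketched below. `NormalizeOneHotAxis` is a hypothetical name, not an onnxruntime function. Because the valid range depends on rank(indices), the sketch assumes the check runs once the input shape is known. It throws `std::out_of_range` in place of `ORT_THROW` so that it stays self-contained.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper, not an existing onnxruntime API: validates a OneHot axis
// and maps a negative value to its non-negative equivalent.
// OneHot's output has rank(indices) + 1 dimensions, so, like the ONNX shape
// inference check, the accepted range is [-(rank(indices) + 1), rank(indices)].
int64_t NormalizeOneHotAxis(int64_t axis, int64_t indices_rank) {
  const int64_t out_rank = indices_rank + 1;
  if (axis < -out_rank || axis >= out_rank) {
    throw std::out_of_range("'axis' must be in [-rank(indices)-1, rank(indices)], got " +
                            std::to_string(axis));
  }
  // A negative axis counts back from the last output dimension,
  // the same convention Gather, Pad and Slice use.
  const int64_t normalized = axis < 0 ? axis + out_rank : axis;
  assert(normalized >= 0 && normalized < out_rank);
  return normalized;
}
```

For 2D indices (3D output), -3, -2 and -1 map to 0, 1 and 2, while -4 is rejected. Only values outside the range are errors, instead of every value below -1.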

System information

To Reproduce

Pass axis = -3, e.g.:

// 2D to 3D with negative axis.
{
  "op_type": "OneHot",
  "indices": [[1, 0, 3],
              [0, 2, 0]],
  "depth": [4],
  "values": [1, 2],
  "output": [[[1, 2, 1],[2, 1, 2]],
             [[2, 1, 1],[1, 1, 1]],
             [[1, 1, 1],[1, 2, 1]],
             [[1, 1, 2],[1, 1, 1]]],
  "axis": -3, // Equivalent to 0.
  "T1": "uint32",
  "T2": "uint32",
  "T3": "float32"
}
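The small C++ reference below can be used to double-check the expected output of this case. `OneHotReference` is a hypothetical helper, not onnxruntime's kernel. It assumes uint32 indices and float values (T1 and T3 above) and takes `values` as separate off/on arguments. Indices at or above `depth` are left as all `off_value`. It normalizes the axis the way this issue proposes, then treats the output as prefix × depth × suffix.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical reference OneHot for checking expected outputs; not onnxruntime's kernel.
// Assumes uint32 indices and float values (T1 and T3 in the repro). Returns the
// output flattened in row-major order; its shape is indices_shape with `depth`
// inserted at the normalized axis.
std::vector<float> OneHotReference(const std::vector<uint32_t>& indices,
                                   const std::vector<int64_t>& indices_shape,
                                   int64_t depth, float off_value, float on_value,
                                   int64_t axis) {
  const int64_t out_rank = static_cast<int64_t>(indices_shape.size()) + 1;
  if (axis < -out_rank || axis >= out_rank) {
    throw std::out_of_range("axis must be in [-rank(indices)-1, rank(indices)]");
  }
  if (axis < 0) axis += out_rank;  // count back from the last output dimension
  assert(depth > 0);

  // The depth dimension sits between `prefix` (product of the indices dims
  // before axis) and `suffix` (product of the indices dims from axis on).
  int64_t prefix = 1;
  int64_t suffix = 1;
  for (int64_t i = 0; i + 1 < out_rank; ++i) {
    if (i < axis) {
      prefix *= indices_shape[static_cast<size_t>(i)];
    } else {
      suffix *= indices_shape[static_cast<size_t>(i)];
    }
  }
  assert(static_cast<int64_t>(indices.size()) == prefix * suffix);

  std::vector<float> output(static_cast<size_t>(prefix * depth * suffix), off_value);
  for (int64_t p = 0; p < prefix; ++p) {
    for (int64_t s = 0; s < suffix; ++s) {
      const int64_t index = indices[static_cast<size_t>(p * suffix + s)];
      if (index < depth) {  // indices >= depth leave the whole vector at off_value
        output[static_cast<size_t>((p * depth + index) * suffix + s)] = on_value;
      }
    }
  }
  return output;
}
```

With indices {1, 0, 3, 0, 2, 0}, shape {2, 3}, depth 4, off 1, on 2 and axis -3, the flattened result matches the 4×2×3 `output` listed above. It is identical to the result for axis 0.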

Expected behavior

Given 2D input (3D output), axis = -3 should be treated as axis = 0, and axis = -2 as axis = 1.
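Stated as a formula, let r = rank(indices), so the output rank is r + 1. The normalization performed by the ONNX shape-inference check is then:

```latex
\text{axis}' =
\begin{cases}
\text{axis} + (r + 1), & -(r + 1) \le \text{axis} < 0 \\
\text{axis}, & 0 \le \text{axis} \le r
\end{cases}
\qquad
r = 2:\quad -3 \mapsto 0,\; -2 \mapsto 1,\; -1 \mapsto 2
```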

hariharans29 commented 5 years ago

Hi @fdwr

Don't you mean:

For 4D input, the output would be 5D, so axis = -3 should be treated as -3 + 5 = 2 and axis = -2 as -2 + 5 = 3?

(Or)

Continuing from the illustrative example, I think you meant "Given 2D" (not "Given 4D").

fdwr commented 5 years ago

@hariharans29 Indeed :). Updated typo in-place.