microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Index put loop model regression with ort==1.18 #20855

Open titaiwangms opened 4 months ago

titaiwangms commented 4 months ago

Describe the issue

The error is raised starting with 1.18. The same model runs fine with 1.17.3.

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Loop node. Name:'/Loop' Status Message: Non-zero status code returned while running ScatterND node. Name:'/ScatterND_10' Status Message: invalid indice found, indice = 8
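For context, the failing node is a ScatterND, presumably emitted for the in-place `bias[...] = 5` assignments (index_put) inside the exported loops. Per the ONNX spec, ScatterND copies `data` and, for each index tuple stored along the last axis of `indices`, overwrites the addressed element or slice with the matching entry of `updates`. The "invalid indice found" message indicates that one of those indices fell outside its axis. The sketch below follows the ONNX reference semantics (reduction='none') with an explicit bounds check. It is not ORT's C++ kernel, and it assumes the valid range for an axis of size s is [-s, s-1], with negative indices wrapping as in NumPy.

```python
import numpy as np


def scatter_nd(data: np.ndarray, indices: np.ndarray, updates: np.ndarray) -> np.ndarray:
    """Reference-style ScatterND (reduction='none') with a bounds check.

    The last axis of `indices` holds a (possibly partial) index tuple into
    `data`; the element or slice it addresses is replaced by the matching
    entry of `updates`. `data` itself is left untouched.
    """
    output = np.copy(data)
    for pos in np.ndindex(indices.shape[:-1]):
        idx = tuple(int(v) for v in indices[pos])
        for axis, v in enumerate(idx):
            size = data.shape[axis]
            # Assumption: [-size, size - 1] is valid; negatives wrap like NumPy.
            if not -size <= v < size:
                raise ValueError(f"invalid indice found, indice = {v}")
        output[idx] = updates[pos]
    return output


data = np.zeros((2, 8), dtype=np.float32)
print(scatter_nd(data, np.array([[1, 7]]), np.array([5.0], dtype=np.float32))[1, 7])  # 5.0
try:
    scatter_nd(data, np.array([[0, 8]]), np.array([5.0], dtype=np.float32))
except ValueError as err:
    print(err)  # invalid indice found, indice = 8
```

Read this way, `indice = 8` presumably means an index tuple computed inside the Loop body addressed position 8 of an axis that had at most 8 entries at that point.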

To reproduce

(1) With the uploaded model test_index_put_loop.zip (containing test_index_put_loop.onnx):

import onnx
import onnxruntime
import torch

# Same run-time input as in (2): both axes of "x" are dynamic.
y = torch.randn(4, 1)

onnx_model = onnx.load("test_index_put_loop.onnx")
ort_session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])

onnxruntime_input = {
    k.name: v.numpy(force=True)
    for k, v in zip(ort_session.get_inputs(), [y])
}
ort_session.run(None, onnxruntime_input)

(2) From PyTorch

import torch
import onnx
import onnxruntime

@torch.jit.script
def ngram_attention_bias(
    sequence_length: int, ngram: int, device: torch.device, dtype: torch.dtype
):
    bias = torch.ones(
        (ngram, sequence_length), device=device, dtype=dtype
    ) * float("-inf")
    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias = bias * 2
            bias[stream_idx, i] = 5
            bias = bias * 5
            bias[0, 0] = 5

    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias[stream_idx, i] = 5
            bias[0, i] = 5
    return bias

class ScriptModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ngram = 2
        self.max_target_positions = 512

    def forward(self, hidden_states):
        seq_length, batch_size = hidden_states.shape[:2]
        predict_causal_mask = ngram_attention_bias(
            self.max_target_positions,
            self.ngram,
            hidden_states.device,
            hidden_states.dtype,
        )
        predict_causal_mask = predict_causal_mask[:, :seq_length]
        return predict_causal_mask

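x = torch.randn(6, 2)  # export-time input: seq_length=6, batch_size=2
y = torch.randn(4, 1)  # run-time input with different sizes, exercising the dynamic axes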
torch.onnx.export(
    torch.jit.script(ScriptModel()),
    x,
    "test_index_put_loop.onnx",
    input_names=["x"],
    dynamic_axes={"x": {0: "seq_length", 1: "batch_size"}},
)

onnx_model = onnx.load("test_index_put_loop.onnx")
ort_session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])

onnxruntime_input = {
    k.name: v.numpy(force=True)
    for k, v in zip(ort_session.get_inputs(), [y])
}
ort_session.run(None, onnxruntime_input)
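To check what a correct run should return, here is a NumPy port of `ngram_attention_bias` and of `ScriptModel.forward`. The names `ngram_attention_bias_np` and `expected_output` are hypothetical helpers, not part of the original test. Because the second pair of loops overwrites every entry with 5, the expected result is simply a (2, seq_length) array of 5.0, i.e. (2, 4) for `y`.

```python
import numpy as np


def ngram_attention_bias_np(sequence_length: int, ngram: int, dtype=np.float32) -> np.ndarray:
    """NumPy port of the TorchScript ngram_attention_bias above (CPU only)."""
    bias = np.full((ngram, sequence_length), -np.inf, dtype=dtype)
    # Entries set to 5 are multiplied by 10 per iteration and overflow to +inf;
    # silence the overflow warning, since the second loop overwrites them anyway.
    with np.errstate(over="ignore"):
        for stream_idx in range(ngram):
            for i in range(sequence_length):
                bias = bias * 2
                bias[stream_idx, i] = 5
                bias = bias * 5
                bias[0, 0] = 5
    for stream_idx in range(ngram):
        for i in range(sequence_length):
            bias[stream_idx, i] = 5
            bias[0, i] = 5
    return bias


def expected_output(seq_length: int, ngram: int = 2, max_target_positions: int = 512) -> np.ndarray:
    """What ScriptModel.forward should return for an input with seq_length rows."""
    return ngram_attention_bias_np(max_target_positions, ngram)[:, :seq_length]


print(expected_output(4).shape)  # (2, 4)
```

On a version where the session runs, `np.testing.assert_allclose(ort_session.run(None, onnxruntime_input)[0], expected_output(4))` would compare ORT's result against this reference.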

Urgency

This was spotted in a PyTorch converter test case.

Platform

Linux

OS Version

VERSION="2.0.20240301" MARINER

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

798cea2350a196a67ff7e0621ea125c7f2035f7c

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.