microsoft / onnxruntime


BatchNorm fails on CUDA EP with zero length sequences #10128

Open david-macleod opened 2 years ago

david-macleod commented 2 years ago

Describe the bug

When passing a tensor with a zero-size dimension, e.g. (8, 1024, 0), to BatchNorm1d, we hit the following error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running BatchNormalization node. Name:'BatchNormalization_0' Status Message: CUDNN error executing cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data())

This is not an issue for the CPU EP, and zero-size dimensions should be supported according to the ONNX spec.

Thank you

System information

To Reproduce

import torch
import onnxruntime as ort
import tempfile

class Model(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.bnorm = torch.nn.BatchNorm1d(2048)

    def forward(self, x):
        x = self.bnorm(x)
        return x

x = torch.randn(1, 2048, 0)
model = torch.jit.script(Model())
model.eval()

with tempfile.TemporaryDirectory() as temp_dir:
    temp_onnx = temp_dir + "/tmp.onnx"  # keep the file inside temp_dir so it is cleaned up
    torch.onnx.export(model, x, temp_onnx, opset_version=14, input_names=["x"], dynamic_axes={"x":[2]}, example_outputs=x)

    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    for providers in (['CPUExecutionProvider'], ["CUDAExecutionProvider"]):
        session = ort.InferenceSession(temp_onnx, options, providers=providers)
        print("EPs", session.get_providers())
        output = session.run(None, input_feed={"x": x.numpy()})[0]
        print("Output", output.shape)

Expected behavior

A successful inference pass, as demonstrated with the CPU EP.
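
Until the CUDA EP handles this case, one possible caller-side workaround is to skip the session call when the input has no elements. The sketch below assumes the model's output has the same shape and dtype as its input, which holds for the single-BatchNorm model above but not for models in general. `run_skipping_empty` and `run_fn` are hypothetical names, not part of the onnxruntime API.

```python
import numpy as np

def run_skipping_empty(run_fn, x):
    """Call run_fn(x) unless x has zero elements.

    run_fn is any callable mapping an input array to an output array,
    e.g. lambda a: session.run(None, {"x": a})[0] (hypothetical wiring).
    Assumes the output shape equals the input shape, as for BatchNorm.
    """
    if x.size == 0:
        # Nothing to normalize: return an empty array of the same shape.
        return np.empty(x.shape, dtype=x.dtype)
    return run_fn(x)
```

With this guard, the failing CUDA call is never reached for inputs such as `x.numpy()` with shape (1, 2048, 0), while non-empty inputs still go through the session.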

hariharans29 commented 2 years ago

The CUDA BatchNormalization op is missing a check for zero-sized inputs and the logic to exit early, like this one in the CUDA Conv op: https://github.com/microsoft/onnxruntime/blob/a367f0664d831d4fb2557e8d63fc09d67d7386fe/onnxruntime/core/providers/cuda/nn/conv.cc#L180.

Are you encountering this in a real-world model, and would you like to contribute this fix?
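
The early exit described in this comment can be sketched in NumPy as follows. This is an illustrative reference only, not the ORT CUDA kernel (which is written in C++). `batchnorm_inference` is a hypothetical name, and the formula is the standard inference-mode batch normalization over the channel axis.

```python
import numpy as np

def batchnorm_inference(x, scale, bias, mean, var, eps=1e-5):
    """Reference inference-mode batch norm over the channel axis (axis 1)."""
    if x.size == 0:
        # Early exit: the output has the input's shape, so there is nothing
        # to compute and no backend tensor descriptor needs to be built.
        return np.empty(x.shape, dtype=x.dtype)
    # Reshape per-channel parameters so they broadcast against (N, C, ...).
    bshape = (1, -1) + (1,) * (x.ndim - 2)
    scale, bias, mean, var = (p.reshape(bshape) for p in (scale, bias, mean, var))
    return (x - mean) / np.sqrt(var + eps) * scale + bias
```

The check sits before any per-channel work, so an input of shape (1, 2048, 0) returns an empty output of the same shape instead of reaching the backend.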

stale[bot] commented 2 years ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.