microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Feature Request] Add torch.Tensor support for InferenceSession input_feed #20481

Open thiagocrepaldi opened 6 months ago

thiagocrepaldi commented 6 months ago

Describe the feature request

Allowing ONNX Runtime to accept torch.Tensor inputs (in addition to the current numpy arrays) is useful for scenarios in which numpy does not support the data type used in the original torch model, such as torch.bfloat16.

Describe scenario use case

Today, we are forced to transform the ONNX graph to convert bfloat16 into float16 due to numpy's lack of support for bfloat16.
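The limitation is easy to verify: stock numpy has no native bfloat16 dtype, so even looking one up fails (a minimal sketch):

```python
import numpy as np

# Stock numpy does not register a "bfloat16" dtype, so this lookup
# raises TypeError ("data type 'bfloat16' not understood").
try:
    np.dtype("bfloat16")
    has_bfloat16 = True
except TypeError:
    has_bfloat16 = False

print(has_bfloat16)  # False on stock numpy
```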

Supporting torch.Tensor directly would also bring ORT closer to PyTorch's original model, without numpy as a middleman.
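For illustration, here is a sketch of the bridge users currently have to write by hand (the helper name `to_numpy_feed` is hypothetical, not an ORT API): a bfloat16 tensor must first be cast to a numpy-representable dtype before it can go into `input_feed`:

```python
import numpy as np
import torch

def to_numpy_feed(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: convert a torch.Tensor into a numpy array
    usable as an InferenceSession input_feed value.

    bfloat16 has no numpy equivalent, so it is upcast to float32 first,
    losing the compact storage the original model used.
    """
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    # .numpy() requires a detached CPU tensor
    return t.detach().cpu().numpy()

x = torch.ones(2, 3, dtype=torch.bfloat16)
# What one would then pass to session.run(None, feed):
feed = {"input": to_numpy_feed(x)}
```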

wschin commented 6 months ago

Workarounds for running ORT with PyTorch tensors -- https://github.com/microsoft/onnxruntime/issues/20281 & https://github.com/pytorch/pytorch/blob/4f29103749c5011529f1abb10b1508a682588909/torch/onnx/_internal/onnxruntime.py#L414 (if onnxruntime-training is installed). It's a helpful thing to implement along the way but right now, there is no ETA.