microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[JAVA] Ability to construct a Tensor from a GPU memory pointer #20966

Open balenamiaa opened 5 months ago

balenamiaa commented 5 months ago

Describe the issue

The C# API:

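        // dataPtr: CUDA device pointer to the mapped resource; MemoryInfo must describe CUDA memory.
        var dataPtr = cudaResource.GetMappedPointer();
        // Last argument: buffer size in bytes (3 channels * width * height float32 values).
        var ortValue = OrtValue.CreateTensorValueWithData(MemoryInfo, tensorElementType, Shape, dataPtr, 3 * width * height * sizeof(float));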

This creates a tensor directly from a raw GPU pointer. That avoids copying the data to the CPU, which matters when even one extra millisecond of latency is significant. No equivalent API is exposed on the JVM side. I currently work around this by calling the ONNX Runtime C API through Project Panama's FFI, but this should be exposed through JNI with a corresponding Java API.
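The last argument in the C# call, 3 * width * height * sizeof(float), is the buffer length in bytes. It equals the element count implied by Shape multiplied by the size of one element. A Java-side API that accepts a raw device pointer would need the same calculation, because the JVM cannot inspect device memory to check it. Below is a minimal sketch, assuming Shape is {1, 3, height, width} with float32 data (a {3, height, width} shape gives the same product). TensorBytes is a hypothetical helper, not part of the ONNX Runtime Java API.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of the ONNX Runtime Java API.
final class TensorBytes {
    private TensorBytes() {}

    // Number of elements implied by a tensor shape, rejecting negative dims and overflow.
    static long elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("negative dimension in " + Arrays.toString(shape));
            }
            count = Math.multiplyExact(count, dim);
        }
        return count;
    }

    // Buffer size in bytes: element count times the size of one element.
    static long byteLength(long[] shape, int bytesPerElement) {
        if (bytesPerElement <= 0) {
            throw new IllegalArgumentException("bytesPerElement must be positive");
        }
        return Math.multiplyExact(elementCount(shape), (long) bytesPerElement);
    }

    // True when a caller-supplied byte length matches what the shape requires.
    static boolean matches(long[] shape, int bytesPerElement, long suppliedBytes) {
        return byteLength(shape, bytesPerElement) == suppliedBytes;
    }
}
```

For a 640x480 RGB float image, TensorBytes.byteLength(new long[]{1, 3, 480, 640}, Float.BYTES) gives 3686400. Computing in long with Math.multiplyExact makes an overflow throw an exception instead of producing a wrong length.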

To reproduce

~

Urgency

No response

Platform

Windows

OS Version

~

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

~

ONNX Runtime API

Java

Architecture

X64

Execution Provider

CUDA, TensorRT

Execution Provider Library Version

No response

Craigacp commented 5 months ago

If we expose direct creation of GPU tensors, we'd need a data copy somewhere. I'm not yet sure where the best place for that copy is, or whether GPU tensors should be distinguished in the type system to make it clear that their memory can't be accessed directly.

Is the use case in a pipeline of models all resident on the GPU?

balenamiaa commented 5 months ago

Yup, the models are on the GPU. In my case it's not direct creation of a GPU tensor, but rather wrapping a raw resource: a D3D11 texture converted into a CUDA resource via CUDA's Direct3D 11 interoperability. CreateTensorWithDataAsOrtValue is used for this. I don't think the C# API distinguishes between an OrtValue backed by a GPU buffer and one backed by a CPU buffer. In fact, calling GetTensorMutableRawData on a GPU-backed OrtValue throws a System.AccessViolationException. Personally, I think a distinction in the type system would be nice.
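Both comments point toward making host-backed and device-backed tensor storage distinct types. With that split, reading raw data from a GPU-backed value fails at compile time instead of with an access violation at run time. Below is a minimal sketch of the idea using a Java 17 sealed interface. TensorStorage, HostStorage, DeviceStorage and StorageInfo are hypothetical names, not part of the ONNX Runtime Java API.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Objects;

// Hypothetical: where a tensor's bytes live. Not ONNX Runtime API.
sealed interface TensorStorage permits HostStorage, DeviceStorage {
    long byteLength();
}

// Host-resident data: the JVM can read and write it directly.
record HostStorage(ByteBuffer buffer) implements TensorStorage {
    HostStorage {
        Objects.requireNonNull(buffer, "buffer");
    }

    public long byteLength() {
        return buffer.capacity();
    }

    // Only host storage offers a view of the contents.
    FloatBuffer asFloats() {
        return buffer.duplicate().order(ByteOrder.nativeOrder()).asFloatBuffer();
    }
}

// Device-resident data: an opaque address only; deliberately no accessor for the contents.
record DeviceStorage(long devicePointer, int deviceId, long byteLength) implements TensorStorage {
    DeviceStorage {
        if (devicePointer == 0) throw new IllegalArgumentException("null device pointer");
        if (byteLength <= 0) throw new IllegalArgumentException("byteLength must be positive");
    }
}

final class StorageInfo {
    private StorageInfo() {}

    // Dispatch on the storage kind; the sealed hierarchy limits the cases to these two.
    static String describe(TensorStorage s) {
        if (s instanceof HostStorage h) return "host, " + h.byteLength() + " bytes";
        if (s instanceof DeviceStorage d) return "device " + d.deviceId() + ", " + d.byteLength() + " bytes";
        throw new AssertionError("unreachable: sealed hierarchy");
    }
}
```

Only HostStorage has asFloats(), so code holding a DeviceStorage has no way to read the contents from the JVM. On Java 21, a pattern-matching switch over TensorStorage would also be checked for exhaustiveness.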

Craigacp commented 5 months ago

OK, so if you've created the tensor elsewhere via JNI, what type would it have so that it can be wrapped in an OrtValue? Is it a bare pointer (a long)?

yuslepukhin commented 4 months ago

Do not forget about disposing native resources.
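On the question of what type a JNI- or Panama-created buffer would have: at the boundary it would most likely be a bare long address. A bare long, however, says nothing about who frees the memory, which is the disposal concern raised above. Below is one possible shape, purely a hypothetical sketch. DeviceBuffer and its release callback are illustrative names, not ONNX Runtime API; the callback stands in for whatever actually unmaps or frees the resource. The sketch pairs the address with a release action and makes close() idempotent.

```java
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.LongConsumer;

// Hypothetical wrapper, not ONNX Runtime API: a raw device address plus the action that releases it.
final class DeviceBuffer implements AutoCloseable {
    private final long address;
    private final long byteLength;
    private final LongConsumer release; // assumed: e.g. unmaps the CUDA resource
    private final AtomicBoolean closed = new AtomicBoolean(false);

    DeviceBuffer(long address, long byteLength, LongConsumer release) {
        if (address == 0) throw new IllegalArgumentException("null device address");
        if (byteLength <= 0) throw new IllegalArgumentException("byteLength must be positive");
        this.address = address;
        this.byteLength = byteLength;
        this.release = Objects.requireNonNull(release, "release");
    }

    // Refuse to hand out the address after release, to avoid use-after-free.
    long address() {
        if (closed.get()) throw new IllegalStateException("buffer already released");
        return address;
    }

    long byteLength() {
        return byteLength;
    }

    boolean isClosed() {
        return closed.get();
    }

    @Override
    public void close() {
        // compareAndSet makes close() idempotent: the release action runs at most once.
        if (closed.compareAndSet(false, true)) {
            release.accept(address);
        }
    }
}
```

CreateTensorWithDataAsOrtValue in the C API does not take ownership of the caller's buffer, so the buffer must outlive any tensor that wraps it. With try-with-resources, resources close in reverse order of declaration. Declaring the DeviceBuffer before the tensor therefore releases the memory only after the tensor is closed.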

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.