microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Using GPU in C++ #13380

Open EmreOzkose opened 2 years ago

EmreOzkose commented 2 years ago

Hi,

Is there an example of using the GPU in C++? Is it enough to add the lines below to the code to use the GPU?

OrtCUDAProviderOptions cudaOptions;
cudaOptions.device_id = 0;
sessionOptions.AppendExecutionProvider_CUDA(cudaOptions);

I am testing my ONNX model on both CPU and GPU. The GPU is about 10x slower than the CPU. Should I change something about the allocator, or add some optimization parameters (specifically for the GPU device)?

Versions: onnxruntime 1.12.1, CUDA 11.3, cuDNN 8.6.0.163_cuda11
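
For reference, a minimal sketch of the full setup being asked about, assuming onnxruntime 1.12 on Linux; the environment name and the model path "model.onnx" are placeholders, not from the original post:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "gpu-test");
  Ort::SessionOptions sessionOptions;

  // Register the CUDA execution provider; ops it does not support
  // fall back to the default CPU provider.
  OrtCUDAProviderOptions cudaOptions;
  cudaOptions.device_id = 0;
  sessionOptions.AppendExecutionProvider_CUDA(cudaOptions);

  // "model.onnx" is a placeholder; on Windows the path would be a wide string.
  Ort::Session session(env, "model.onnx", sessionOptions);
  return 0;
}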

BitCourier commented 2 years ago

Maybe https://onnxruntime.ai/docs/performance/tune-performance.html#register-the-ep and https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc can help you?

EmreOzkose commented 2 years ago

I saw the CUDA-ExecutionProvider page; that is why I added the CUDA options to the session options. What is the difference between the two? In those examples, a CUDA provider is created and added to the session options exactly as I did.

BitCourier commented 2 years ago

You created the cudaOptions struct without setting all the parameters mentioned in the example. Maybe one of those will change your results. The V2 options struct also lets you set some parameters that the old struct doesn't expose.

EmreOzkose commented 2 years ago

I cannot use the V2 options; somehow it is missing. I will try to use it.

I also tried using the GPU on Linux. I just added the 3 lines above, and the ONNX model on GPU is about 8x faster than on CPU. The difference is that on Linux I built onnxruntime from source.

BitCourier commented 2 years ago

I cannot use the V2 options; somehow it is missing. I will try to use it.

You need to use the C API for that:

const OrtApi& api = Ort::GetApi();
OrtCUDAProviderOptionsV2* cuda_options = nullptr;  // note: a pointer, owned by the caller
Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
std::vector<const char*> keys{"device_id", ...};   // parallel key/value C-string arrays
std::vector<const char*> values{"0", ...};
Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size()));

And then use the "AppendExecutionProvider_CUDA_V2" method of your Ort::SessionOptions instance.
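
A sketch of that final step, assuming a version where Ort::SessionOptions exposes AppendExecutionProvider_CUDA_V2 (api and cuda_options come from the snippet above):

Ort::SessionOptions sessionOptions;
sessionOptions.AppendExecutionProvider_CUDA_V2(*cuda_options);
// The caller still owns the V2 options struct and must release it afterwards.
api.ReleaseCUDAProviderOptions(cuda_options);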

EmreOzkose commented 2 years ago

I realised that if I do a warmup for all possible inputs, it works. However, I am using speech data, and the range of possible input shapes is quite large.

EmreOzkose commented 2 years ago

For example, if I do a warmup with a (500, 80) shaped input, the next (500, 80) shaped input will be much faster than the first. But if I then test with a (501, 80) shaped input, it is still slow. Hence I had to do a warmup like

(50, 80), ... (1000, 80), ... (5000, 80) ...
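
A hypothetical sketch of that warmup, assuming an existing session, inputNames, and outputNames, and a model taking a single (length, 80) float tensor:

// Run zero-filled dummy inputs across the expected shape range so each
// shape has been executed once before real traffic arrives. With the
// default exhaustive search, every distinct length the model will see
// must appear here, which is what makes this approach impractical.
auto memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
for (int64_t len = 50; len <= 5000; len += 50) {
  std::vector<float> data(len * 80, 0.0f);
  std::array<int64_t, 2> shape{len, 80};
  Ort::Value input = Ort::Value::CreateTensor<float>(
      memInfo, data.data(), data.size(), shape.data(), shape.size());
  session.Run(Ort::RunOptions{nullptr}, inputNames, &input, 1, outputNames, 1);
}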

BitCourier commented 2 years ago

I think it's because of graph optimization. SessionOptions has some parameters to tune that: https://onnxruntime.ai/docs/performance/graph-optimizations.html. Try setting "GraphOptimizationLevel::ORT_DISABLE_ALL".

EmreOzkose commented 2 years ago

I set it with sessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL); but observed the same behavior. I am feeding one sample 5 times for warmup: the first run takes 900 ms, but the 2nd, 3rd, 4th, and 5th finish in 20 ms. After that, when the script starts to feed different samples with different shapes, the elapsed times increase again.

bencherian commented 2 years ago

For example, if I do a warmup with a (500, 80) shaped input, the next (500, 80) shaped input will be much faster than the first. But if I then test with a (501, 80) shaped input, it is still slow. Hence I had to do a warmup like

(50, 80), ... (1000, 80), ... (5000, 80) ...

This is likely related to cuDNN convolution tuning. The default behavior is to exhaustively search all convolution algorithms for a given input shape, so every time your input shape changes it redoes this search. Try using the heuristic option instead. It will likely pick a slightly slower algorithm than the exhaustive search would, but there would be no need for a warmup.

EmreOzkose commented 2 years ago

When I changed only the line below, the ONNX model became about 10x faster. The speed is still only on par with the CPU, but HEURISTIC certainly works. I realised that short samples (I am using speech data) take longer on GPU; as the sample shape grows, the elapsed time on GPU goes down. I think it is somehow a transfer issue (from CPU to GPU or vice versa).

cudaOptions.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::HEURISTIC; // default is the exhaustive search

I will do some experiments (testing with only long samples / only short samples) to confirm this observation.
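
If you end up on the V2 options from earlier in the thread, my understanding is that the same setting is exposed as a string key (api and cuda_options as in the earlier C API snippet):

const char* keys[] = {"device_id", "cudnn_conv_algo_search"};
const char* values[] = {"0", "HEURISTIC"};
Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 2));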

HyeonseopLim commented 8 months ago

@EmreOzkose

cudaOptions.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::HEURISTIC; // default is the exhaustive search

It works well, thanks! You also said that you built ORT from source. Does that mean you just got the source code and built it with CMake?

EmreOzkose commented 8 months ago

@HyeonseopLim , as far as I remember, yes.
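
For reference, building from source typically goes through the repo's ./build.sh script (a wrapper around CMake), along the lines of ./build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda, with the CUDA/cuDNN paths adjusted to the local installation.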