microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Performance] Computation time of iteratively applying neural network in a single ONNX model using CUDA Execution Provider dominated by Memcpy #16625

Open thomas-yu-1 opened 1 year ago

thomas-yu-1 commented 1 year ago

Describe the issue

I am running inference with an ONNX model exported via torch.jit.script. I have attached the model file (Restormer) and the ONNX file below. The model comes from https://github.com/swz30/Restormer, where I made modifications to make it script-friendly. To test, my overall forward function applies the model iteratively, e.g.

```python
def forward(self, image_recon):
    for i in range(4):
        image_recon = model(image_recon)
    return image_recon
```

where `model` is the Restormer model and `image_recon` is a tensor, e.g. 20x6x512x352. The model itself can briefly be described as a U-Net with a mixture of 2D convolutions and attention modules applied along the channel dimension.

To be clear: my ONNX model consists of a for loop that applies this neural network 4 times in succession.

What I found is that the vast majority of the runtime is dominated by MemcpyFromHost nodes between iterations of applying the model. I'm wondering whether this makes sense: since I am using the CUDA Execution Provider, everything should in principle stay on the GPU. I have also attached the profiling log below.

Thank you very much in advance!

neural_network_only_4_iterations.zip

To reproduce

This is a copy of the model file (model found at the bottom of the file).

restormer_arch_ts_friendly.zip

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.12.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

11.4

Model File

Restormer_only_network_with_for_loop_four_cascade.zip

Is this a quantized model?

No

tianleiwu commented 1 year ago

Please refer to the following stable diffusion example for I/O binding (the input and output tensors should stay in GPU memory across iterations for best performance):

https://github.com/microsoft/onnxruntime/blob/9a126c9add52286873d676b491e5432a0d3d3913/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py

https://github.com/microsoft/onnxruntime/blob/9a126c9add52286873d676b491e5432a0d3d3913/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py
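
For reference, a minimal sketch of this kind of I/O binding with the ONNX Runtime Python API; the model path, the input name "image_recon", the output name "output", and the shape are assumptions, not values taken from the attached model:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Put the input on the GPU once as an OrtValue.
x = np.random.randn(20, 6, 512, 352).astype(np.float32)
x_gpu = ort.OrtValue.ortvalue_from_numpy(x, "cuda", 0)

io_binding = session.io_binding()
io_binding.bind_ortvalue_input("image_recon", x_gpu)   # assumed input name
io_binding.bind_output("output", "cuda")               # let ORT allocate the output on the GPU

session.run_with_iobinding(io_binding)
result = io_binding.get_outputs()[0]   # OrtValue still resident on the GPU
# result.numpy() copies back to host only when actually needed.
```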

thomas-yu-1 commented 1 year ago

Hi, thanks for the tip! But I tried using IOBinding as described in https://onnxruntime.ai/docs/api/python/api_summary.html, and I get the same problem whether I use it or not. From what I understood, IOBinding binds the input and final output of the ONNX model to a device; these Memcpy nodes, however, come from intermediate steps inside the ONNX model, so shouldn't those already stay on the GPU?

thomas-yu-1 commented 1 year ago

So after some more digging, I found that the cause is a normalization I am doing within the model. In a single iteration, I normalize the input by the mean/std, run it through the network, then un-normalize by the same mean/std. This is repeated 4 times in my use case.

Briefly, the pseudocode for a single iteration would be

```python
mean, std = normalization(image)
normalized_image = (image - mean) / std
image_output = neural_network(normalized_image)
image = image_output * std + mean
```

I found that when I run this in a loop, there is a MemcpyToHost right before the normalization; there is no MemcpyToHost if the normalization is removed.

This has a non-negligible effect on the computation time in my case (1.5 s with normalization vs. 1.1 s without, entirely due to the MemcpyToHost calls), so I am wondering if there is any way to avoid this? I was not expecting this kind of normalization to incur such a cost.

tianleiwu commented 1 year ago

It is recommended to change the logging level through the session options to see which node is not placed on the GPU. Then change your script so the model is exported with operators that the CUDA EP supports. You may look at the ONNX and torch documentation to find alternatives, like casting to float before normalization, or trying torch.nn.functional.normalize, etc.
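
A minimal sketch of enabling verbose session logging (the model path is a placeholder); the verbose log reports which execution provider each node is assigned to:

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0  # 0 = VERBOSE; prints node-to-provider assignments

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options=so,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```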

thomas-yu-1 commented 1 year ago

Hi, thanks for the advice. Doing this, I pinpointed that the problem really is the standard deviation; the Memcpy only occurs when I divide by the standard deviation. I even tested with a manual calculation of the standard deviation using Sqrt, Div, and ReduceSumSquare, all operators supported by the CUDAExecutionProvider, but the problem persists. It's not clear to me why even the standard .std() function in torch should force everything to the CPU, and even my manual calculation produces the same behavior.
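
For context, a minimal sketch of the kind of manual normalization being described; the reduction dims, the NCHW layout, and the epsilon are assumptions rather than the exact code used here:

```python
import torch

def normalize(image: torch.Tensor, eps: float = 1e-5):
    # Per-sample statistics over channel and spatial dims (assumed NCHW layout).
    mean = image.mean(dim=(1, 2, 3), keepdim=True)
    var = ((image - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)
    std = torch.sqrt(var + eps)
    return (image - mean) / std, mean, std
```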