microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Performance] Abnormal latencies on certain tasks and a GPU on standby. #17720

Open Manutea opened 1 year ago

Manutea commented 1 year ago

Describe the issue

Some inferences (task IDs around 250, 450, 800 and 1700) take longer than the others. During these iterations the GPU seems to do nothing and sits idle. I have the same problem with P100 and RTX 8000 GPUs, and with both the AlexNet and GoogleNet models.

Perhaps this is related to the discussion in #14023?


I also see these idle periods with onnxruntime_perf_test, using the command:

./onnxruntime_perf_test -I -S 1 -e cuda -r 2048 -p profile.json -s /data/model/googlenet/dynamic_batch_googlenet_opt.onnx

To reproduce

#include <onnxruntime_cxx_api.h>
#include <cuda_runtime.h>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

void onnx_benchmark_GPU(std::string &modelPath, std::string &inputTensorName, std::string &outputTensorName, int deviceId, int batch)
{
  std::vector<float> image(batch * 3 * 224 * 224, 150);
  std::vector<int64_t> inputDims = {batch, 3, 224, 224};
  std::vector<int64_t> outputDims = {batch, 1000};

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "InferenceGPU");
  Ort::SessionOptions sessionOptions;

  //sessionOptions.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
  sessionOptions.EnableProfiling("gpu_profile_file");
  OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, deviceId);
  if (status != nullptr) {
    // Report the provider error and release the status object before exiting
    printf("Provider error: %s\n", Ort::GetApi().GetErrorMessage(status));
    Ort::GetApi().ReleaseStatus(status);
    exit(EXIT_FAILURE);
  }

  Ort::Session session(env, modelPath.c_str(), sessionOptions);
  Ort::MemoryInfo infoCuda("Cuda", OrtAllocatorType::OrtArenaAllocator, deviceId, OrtMemTypeDefault);
  Ort::Allocator cudaAllocator(session, infoCuda);

  int num_iterations = 2048;
  Ort::IoBinding binding(session);
  std::ofstream f("mesureonnx.txt");
  for (int i = 0; i < num_iterations; i++)
  {
    auto input = cudaAllocator.GetAllocation(image.size() * sizeof(float));
    cudaMemcpy(input.get(), image.data(), sizeof(float) * image.size(), cudaMemcpyHostToDevice);

    auto startGPU = std::chrono::high_resolution_clock::now();

    // Create an OrtValue tensor backed by data on CUDA memory
    Ort::Value boundX = Ort::Value::CreateTensor(infoCuda, reinterpret_cast<float*>(input.get()), image.size(), inputDims.data(), inputDims.size());
    // Output element count = product of outputDims, accumulated as int64_t to match the dims type
    std::vector<float> outputData(std::accumulate(outputDims.begin(), outputDims.end(), int64_t{1}, std::multiplies<int64_t>()));
    auto output = cudaAllocator.GetAllocation(outputData.size() * sizeof(float));

    // Create an OrtValue tensor backed by data on CUDA memory
    Ort::Value boundY = Ort::Value::CreateTensor(infoCuda, reinterpret_cast<float*>(output.get()), outputData.size(), outputDims.data(), outputDims.size());
    binding.BindInput(inputTensorName.c_str(), boundX);
    binding.BindOutput(outputTensorName.c_str(), boundY);
    binding.SynchronizeInputs();

    session.Run(Ort::RunOptions(), binding);
    binding.SynchronizeOutputs();
    auto endGPU = std::chrono::high_resolution_clock::now();
    auto durationGPU = std::chrono::duration_cast<std::chrono::nanoseconds>(endGPU - startGPU);
    binding.ClearBoundInputs();
    binding.ClearBoundOutputs();

    // Throughput in images per second: batch / (duration in seconds)
    f << i << " -- GPU inference duration: " << durationGPU.count() << " ns, throughput: " << (1.0 / (durationGPU.count() / 1e9)) * batch << " images/s" << std::endl;
  }
}

Urgency

No response

Platform

Linux

OS Version

CentOS Linux release 7.6.1810 (Core)

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

ONNX Runtime 1.15.0

ONNX Runtime API

C++

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.7

Model File

No response

Is this a quantized model?

Unknown

hariharans29 commented 1 year ago

Is your benchmark code just adding a loop to this test - https://github.com/microsoft/onnxruntime/blob/870b0bc305e163801b47dc4989927a2274cf1e07/onnxruntime/test/shared_lib/test_inference.cc#L1759 ?

The block of code whose duration is measured contains an allocation (GetAllocation()) and two device synchronizations (binding.SynchronizeInputs() and binding.SynchronizeOutputs()). Even if the allocation is for the same number of bytes each time and no real allocation happens on every iteration (because of an underlying memory pool in the allocator), I would move it out of the timed block. In any case, the device synchronization(s) you have there to ensure that the copy on the default stream has completed (cudaMemcpy(input.get(), image.data(), sizeof(float) * image.size(), cudaMemcpyHostToDevice);) may be contributing to the variance if the device was doing something else at that time. I think SynchronizeInputs() is superfluous, because cudaMemcpy() is a blocking call anyway and the data will have been copied to the CUDA buffer by the time it returns. SynchronizeOutputs() isn't really needed either, as Run() should do a stream sync before returning.
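A minimal sketch of the restructuring suggested above, written as a generic harness rather than ONNX Runtime code: one-time setup (in the benchmark above, the GetAllocation(), cudaMemcpy() and BindInput()/BindOutput() calls) runs once outside the timer, and only the per-iteration work (for example session.Run(Ort::RunOptions(), binding)) is timed. The helper name `time_iterations` is hypothetical. It uses `std::chrono::steady_clock`, which is guaranteed to be monotonic, whereas `high_resolution_clock` is not required to be.

```cpp
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper: runs `setup` once (untimed), then times `work`
// for `iterations` calls and returns each duration in nanoseconds.
std::vector<std::int64_t> time_iterations(const std::function<void()>& setup,
                                          const std::function<void()>& work,
                                          int iterations) {
  setup();  // e.g. allocate/copy device buffers and bind inputs/outputs once
  std::vector<std::int64_t> durations;
  durations.reserve(iterations > 0 ? static_cast<std::size_t>(iterations) : 0);
  for (int i = 0; i < iterations; ++i) {
    // steady_clock is monotonic, so a measured interval cannot be negative
    auto start = std::chrono::steady_clock::now();
    work();  // e.g. session.Run(Ort::RunOptions(), binding);
    auto end = std::chrono::steady_clock::now();
    durations.push_back(
        std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count());
  }
  return durations;
}
```

Collecting the durations in memory and writing them to mesureonnx.txt after the loop also keeps file I/O away from the measured iterations.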

Do you see such variance when the input is already on the right device and no IOBinding is used, i.e. when you supply OrtValues backed by CUDA memory through a regular Run() call?

Manutea commented 1 year ago

Hello, thank you for replying.

I tried without using IOBinding.

void onnx_benchmark_GPU(std::string &modelPath, int deviceId)
{
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ModelInference");

  Ort::SessionOptions options;
  options.EnableProfiling("gpu_profile_file");
  OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(options, deviceId);
  if (status != nullptr) {
    // Report the provider error and release the status object before exiting
    printf("Provider error: %s\n", Ort::GetApi().GetErrorMessage(status));
    Ort::GetApi().ReleaseStatus(status);
    exit(EXIT_FAILURE);
  }
  Ort::Session session(env, modelPath.c_str(), options);

  std::array<float, 1*3*224*224> input_data;
  input_data.fill(150.0f);
  std::vector<int64_t> input_dims = {1, 3, 224, 224};

  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data.data(), input_data.size(), input_dims.data(), input_dims.size());

  const char* inputNames[] = {"input"};
  const char* outputNames[] = {"output"};
  std::ofstream f("mesureonnx.txt");

  for (int i = 0; i < 500; ++i)
  {
    auto startGPU = std::chrono::high_resolution_clock::now();
    auto output_tensors = session.Run(Ort::RunOptions{nullptr}, inputNames, &input_tensor, 1, outputNames, 1);
    auto endGPU = std::chrono::high_resolution_clock::now();
    auto durationGPU = std::chrono::duration_cast<std::chrono::nanoseconds>(endGPU - startGPU);
    f << i << " -- GPU inference duration: " << durationGPU.count() << " ns" << std::endl;
  }
}

Here too, some inferences are abnormally slow, and the slow ones occur at the same task IDs in every test run.

[Figure_1]

With Perfetto: [Screenshot from 2023-09-28 10-04-09]

And with nvvp: [Screenshot from 2023-09-28 10-06-23]
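To locate the slow iterations programmatically rather than by eye, each per-iteration duration can be compared against the run's median. A sketch, assuming the durations have already been collected in memory (for example parsed back from mesureonnx.txt); the helper name `slow_iterations` and the default 2x-median threshold are illustrative choices, not part of ONNX Runtime.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns the indices (task IDs) whose duration exceeds
// `factor` times the median duration of the whole run.
std::vector<std::size_t> slow_iterations(const std::vector<std::int64_t>& durations_ns,
                                         double factor = 2.0) {
  std::vector<std::size_t> slow;
  if (durations_ns.empty()) return slow;

  // Median (upper median for even sizes) via nth_element on a copy,
  // so the caller's iteration order is preserved.
  std::vector<std::int64_t> sorted = durations_ns;
  auto mid = sorted.begin() + static_cast<std::ptrdiff_t>(sorted.size() / 2);
  std::nth_element(sorted.begin(), mid, sorted.end());
  const double median = static_cast<double>(*mid);

  for (std::size_t i = 0; i < durations_ns.size(); ++i) {
    if (static_cast<double>(durations_ns[i]) > factor * median) slow.push_back(i);
  }
  return slow;
}
```

Running this on the data from several runs and comparing the returned index lists is one way to check whether the slow iterations really recur at the same task IDs.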

Manutea commented 1 year ago

I've also just tried the CPU execution provider, and there too the CPU appears to be waiting for something.

void onnx_benchmark_CPU(std::string &filePath, std::string &modelPath, std::string &inputTensorName, std::string &outputTensorName, int batch)
{
  std::vector<float> image(batch * 3 * 224 * 224, 150);
  std::vector<int64_t> inputDims = {batch, 3, 224, 224};
  std::vector<int64_t> outputDims = {batch, 1000};

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "InferenceCPU");
  Ort::SessionOptions sessionOptions;
  sessionOptions.EnableProfiling("cpu_profile_file");
  sessionOptions.SetIntraOpNumThreads(1);
  Ort::Session session(env, modelPath.c_str(), sessionOptions);

  auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, image.data(), image.size(), inputDims.data(), inputDims.size());

  const char* inputNames[] = {inputTensorName.c_str()};
  const char* outputNames[] = {outputTensorName.c_str()};

  std::ofstream file(filePath, std::ios::app);
  int nbIterations = 800;
  for (int i = 0; i < nbIterations; i++)
  {
    auto startCPU = std::chrono::high_resolution_clock::now();
    auto outputTensors = session.Run(Ort::RunOptions{nullptr}, inputNames, &inputTensor, 1, outputNames, 1);
    auto endCPU = std::chrono::high_resolution_clock::now();
    auto durationCPU = std::chrono::duration_cast<std::chrono::nanoseconds>(endCPU - startCPU);
    std::cout << "CPU inference duration : " << durationCPU.count() << " ns" << std::endl;
    if (i > 0)
      file << batch << " " << durationCPU.count() << "\n";
  }
  file.close();
}

[Figure_1]

[Screenshot from 2023-10-03 09-06-06]
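The CPU benchmark above drops the first iteration (`if(i>0)`) before writing results, since the first call is often slower than steady state. A sketch of summarizing a run the same way, with the warm-up count as a parameter; `LatencySummary`, `nearest_rank` and `summarize` are hypothetical names, and the percentiles use the nearest-rank method.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct LatencySummary {
  std::int64_t p50_ns;
  std::int64_t p99_ns;
  std::int64_t max_ns;
};

// Nearest-rank percentile on a sorted, non-empty vector (pct in 1..100).
// Integer arithmetic computes ceil(pct * n / 100) without floating-point error.
static std::int64_t nearest_rank(const std::vector<std::int64_t>& sorted, int pct) {
  std::size_t n = sorted.size();
  std::size_t rank = (static_cast<std::size_t>(pct) * n + 99) / 100;
  return sorted[rank == 0 ? 0 : rank - 1];
}

// Hypothetical helper: drops the first `warmup` iterations, then reports the
// median, 99th percentile and maximum latency of the remaining ones.
LatencySummary summarize(const std::vector<std::int64_t>& durations_ns, std::size_t warmup) {
  if (durations_ns.size() <= warmup) throw std::invalid_argument("no samples after warm-up");
  std::vector<std::int64_t> kept(durations_ns.begin() + static_cast<std::ptrdiff_t>(warmup),
                                 durations_ns.end());
  std::sort(kept.begin(), kept.end());
  return {nearest_rank(kept, 50), nearest_rank(kept, 99), kept.back()};
}
```

A large gap between p99 and p50 is a compact way to report the kind of occasional slow iterations described in this issue.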

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.