microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

C++ API, Memory Leak instantiating Ort::Sessions #17451

Open massimiliano96 opened 1 year ago

massimiliano96 commented 1 year ago
#include <onnxruntime_cxx_api.h>

int main()
{
    for (size_t i = 0; i < 10; i++)
    {
        std::string modelPath = std::string("./model/model.onnx");
        Ort::Env env;
        Ort::Session session = Ort::Session(env, modelPath.c_str(), Ort::SessionOptions {});
    }
    return 0;
}

This is a simple code snippet that I'm testing. I noticed that the allocated memory increases at each iteration. I would expect that after each iteration the session object goes out of scope and its memory is freed; instead, the allocated memory grows by almost 150 MB per iteration.

I'm running this code on Ubuntu 22.04 with library version 1.15.1.

How can I manage it?

skottmckay commented 1 year ago

What happens if you move the env out of the loop? That's intended to only be created once.

massimiliano96 commented 1 year ago

Thanks for the quick reply :)

I'm getting the same result. But I was inaccurate: the memory consumption does not grow at every iteration, it only grows for the first three or four iterations until it reaches 400 MB, and then it stays stable for the remaining iterations. Considering that model.onnx weighs 12 MB, 400 MB is a bit much, isn't it?

tianleiwu commented 1 year ago

You can run a debug build; it will report memory leaks if there are any. To investigate, you can add logging for memory allocation like this: it prints each memory allocation and free, so you can find out which allocations are never released.
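For reference, one simple way to surface more detail from the runtime is to construct the environment with a verbose log level. A minimal sketch using the standard Ort::Env / Ort::SessionOptions calls (how much allocator detail actually gets logged depends on the build):

#include <onnxruntime_cxx_api.h>

int main()
{
    // Verbose logging; a debug build prints considerably more detail,
    // which helps pair each allocation with its corresponding free.
    Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "leak-check");

    Ort::SessionOptions options;
    options.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE);  // per-session log level

    Ort::Session session(env, "./model/model.onnx", options);
    return 0;
}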

skottmckay commented 1 year ago

I tried a debug build on Windows and don't see any issues. The memory growth from creating the session matches what is freed at the end of the loop.

[screenshot]

I moved the Env out of the loop as that is the expected usage and loaded the mnist.onnx model from the test data.

[screenshot]

massimiliano96 commented 1 year ago
#include <onnxruntime_cxx_api.h>

int main()
{
    Ort::Env env;
    for (size_t i = 0; i < 20; i++)
    {
        std::string modelPath = std::string("./model/mnist.onnx");
        Ort::Session session = Ort::Session(env, modelPath.c_str(), Ort::SessionOptions {});
    }
    return 0;
}

I tried the same model as in your test; however, on Ubuntu the memory consumption continues to grow:

[screenshots of memory usage after iterations 1-3]

In three iterations it reaches 52 MB

yuslepukhin commented 1 year ago

Try SessionOptions::DisableCpuArena() and see if it makes a difference.
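For reference, in the C++ wrapper the corresponding method on Ort::SessionOptions is DisableCpuMemArena; a minimal sketch of the suggestion (env and modelPath as in the snippets above):

Ort::SessionOptions options;
options.DisableCpuMemArena();  // turn off the CPU memory arena for this session

Ort::Session session(env, modelPath.c_str(), options);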

skottmckay commented 1 year ago

FWIW, the memory arena should only affect the second Run call when it's enabled. On the first call we trace the memory allocations required during model execution, so there are many small allocations from the arena. On the second call we allocate a single large block based on that tracing, so there is one large allocation from the arena, and that one large allocation may require a new chunk from the arena.

However in this scenario there is only a single Run call per Session (not the expected usage in the real world), and we're not sharing an allocator across sessions, so I wouldn't expect the memory arena to be a factor.

massimiliano96 commented 1 year ago

Yes, I confirm that I'm getting the same result even with that option.

pranavsharma commented 1 year ago

I tried this on my Linux machine and I see the memory stabilizing after some iterations. Here's a sample program that prints the usage.

// Copyright(c) Microsoft Corporation.All rights reserved.
// Licensed under the MIT License.
//

#include <iostream>
#include <string>
#include <onnxruntime_cxx_api.h>
#include <sys/resource.h>
#include <unistd.h>

std::size_t GetPeakWorkingSetSize() {
  struct rusage rusage;
  getrusage(RUSAGE_SELF, &rusage);
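  // On Linux, getrusage() reports ru_maxrss in kilobytes, so multiply by 1024 to get bytes.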
  return static_cast<size_t>(rusage.ru_maxrss * 1024L);
}

int main(int argc, char* argv[]) {
  Ort::Env env;
  for (size_t i = 0; i < 100; i++)
  {
    {
      std::string modelPath = std::string("mnist.onnx");
      Ort::Session session = Ort::Session(env, modelPath.c_str(), Ort::SessionOptions {});
    }
    std::cout << "After iteration " << i << " " << GetPeakWorkingSetSize() << "\n";
    sleep(1);
  }
  std::cout << "Done\n";
  return 0;
}
claeyzre commented 8 months ago

Any updates on this? I am facing a similar issue.

memorywave commented 8 months ago

Can you try the following? We hit similar memory leak issues using onnxruntime_go, which calls the onnxruntime C API. The issue happens when we repeatedly create and destroy sessions in the background while, at the same time, running model inference. After making these changes, the memory usage stabilizes and remains relatively low.

  1. Configure the session options (a C++ equivalent is sketched after this comment):
     options.enable_cpu_mem_arena = False
     options.enable_mem_pattern = False
     options.intra_op_num_threads = 1

  2. Use tcmalloc instead of the standard libc allocator used by the onnxruntime library on Linux:

    # install tcmalloc
    apt-get -y install google-perftools
    # set LD_PRELOAD to use tcmalloc instead of the standard malloc implementation
    export LD_PRELOAD=/usr/lib/$(uname -m)-linux-gnu/libtcmalloc.so.4:${LD_PRELOAD}
    # then run your program

The second change was learned from this release note (section "Known Issues"): https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/rel-23-12.html#rel-23-12

Some systems which implement malloc() may not release memory back to the operating system right away causing a false memory leak. This can be mitigated by using a different malloc implementation. Tcmalloc and jemalloc are installed in the Triton container and can be used by specifying the library in LD_PRELOAD. We recommend experimenting with both tcmalloc and jemalloc to determine which one works better for your use case.
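Since this thread is about the C++ API, a rough C++ equivalent of the session-option changes in item 1 above (method names as declared on Ort::SessionOptions in onnxruntime_cxx_api.h) would be:

Ort::SessionOptions options;
options.DisableCpuMemArena();     // options.enable_cpu_mem_arena = False
options.DisableMemPattern();      // options.enable_mem_pattern = False
options.SetIntraOpNumThreads(1);  // options.intra_op_num_threads = 1

Ort::Session session(env, modelPath.c_str(), options);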

jindameias commented 8 months ago

Thank you sir, let me have a try.

cuiyongbo commented 3 months ago

I tried the suggestions above, but the issue still persists.