pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Memory not released after jit.trace/freeze #96726

Open zhuhaozhe opened 1 year ago

zhuhaozhe commented 1 year ago

🐛 Describe the bug

Memory cannot be released with del and gc.collect() after torch.jit.trace/torch.jit.freeze. To track memory allocation and release, I added logging to the CPU allocator:

diff --git a/c10/core/impl/alloc_cpu.cpp b/c10/core/impl/alloc_cpu.cpp
index 6ca9ea10967..c4fd33ae701 100644
--- a/c10/core/impl/alloc_cpu.cpp
+++ b/c10/core/impl/alloc_cpu.cpp
@@ -6,6 +6,8 @@
 #include <c10/util/irange.h>
 #include <c10/util/numa.h>

+#include <iostream>
+
 // TODO: rename flags to C10
 C10_DEFINE_bool(
     caffe2_cpu_allocator_do_zero_fill,
@@ -94,7 +96,7 @@ void* alloc_cpu(size_t nbytes) {
   } else if (FLAGS_caffe2_cpu_allocator_do_junk_fill) {
     memset_junk(data, nbytes);
   }
-
+  std::cout << "malloc data in c++" << data << " size " << float(nbytes) / 1024 / 1024 << "MB" << std::endl;
   return data;
 }

@@ -103,6 +105,7 @@ void free_cpu(void* data) {
   _aligned_free(data);
 #else
   // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
+  std::cout << "free data " << data << std::endl;
   free(data);
 #endif
 }

With this patch applied, run the following script:

import gc
import os
import time

import psutil
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute (neither a parameter nor a buffer):
        # 1e7 x 10 float32 values = 4e8 bytes, about 381 MB
        self.w1 = torch.rand(int(1e7), 10)
        print("malloc w1", hex(self.w1.data_ptr()), "size",
              self.w1.numel() * self.w1.element_size() / 1024 / 1024, "MB")

    def forward(self, x):
        x = self.w1 + x
        return x


def print_mem(process):
    # Resident set size of this process, in MB
    print("current mem usage:", process.memory_info().rss / 1024 / 1024, "MB")


def run_leak():
    process = psutil.Process(os.getpid())
    print_mem(process)
    a = M().eval()
    inp = torch.zeros(int(1e7), 1)
    print("malloc input", hex(inp.data_ptr()), "size",
          inp.numel() * inp.element_size() / 1024 / 1024, "MB")
    print_mem(process)
    print("trace==============")
    a_trace = torch.jit.trace(a, inp)

    del inp
    print("===============================================delete input")
    time.sleep(2)
    print_mem(process)
    gc.collect()
    print("gc=================================================")
    time.sleep(2)
    print_mem(process)

    print_mem(process)
    del a_trace
    print("===============================================delete a_trace")
    time.sleep(2)
    print_mem(process)
    gc.collect()
    print("gc=================================================")
    time.sleep(2)
    print_mem(process)

    del a
    print("===============================================delete a")
    time.sleep(2)
    print_mem(process)
    gc.collect()
    print("gc=================================================")
    time.sleep(2)
    print_mem(process)


if __name__ == '__main__':
    run_leak()
    print("process exit===========================")

Output:

current mem usage: 282.78515625 MB
malloc data in c++0x7f4c2f8da040 size 381.47MB
malloc w1 0x7f4c2f8da040 size 381.4697265625 MB
malloc data in c++0x7f4c2d2b4040 size 38.147MB
malloc input 0x7f4c2d2b4040 size 38.14697265625 MB
current mem usage: 701.67578125 MB
trace==============
...
...
===============================================delete input
current mem usage: 713.7421875 MB
gc=================================================
current mem usage: 713.7421875 MB
current mem usage: 713.7421875 MB
free data 0x7f4c2d2b4040
===============================================delete a_trace
current mem usage: 675.59375 MB
gc=================================================
current mem usage: 675.59375 MB
===============================================delete a
current mem usage: 675.59375 MB
gc=================================================
current mem usage: 675.59375 MB
process exit===========================
free data 0x7f4c2f8da040
free data 0x7f4c2f8da040

0x7f4c2f8da040 (w1) is released only just before the process exits. Neither del nor gc.collect() frees it.

The backtrace of the release shows that the last reference-count decrement comes from torch::jit::CompilationUnit. The CompilationUnit owns the graph, and w1 is held by a node of that graph. I have already deleted all related variables in Python, so I do not know whether this PyObject is still reachable in a way that would let the memory be released.

#26 0x00007fffdf9e885e in torch::jit::Node::~Node (this=0x555559e66a50, __in_chrg=<optimized out>)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/torch/csrc/jit/ir/ir.h:820
#27 0x00007fffdf9e2628 in torch::jit::Graph::~Graph (this=0x555559e61920, __in_chrg=<optimized out>)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/torch/csrc/jit/ir/ir.cpp:2003
#28 0x00007fffdf9a1092 in std::_Sp_counted_ptr<torch::jit::Graph*, (__gnu_cxx::_Lock_policy)2>::_M_dispose (
    this=0x555559e65440) at /usr/include/c++/11/bits/shared_ptr_base.h:348
#29 0x00007fffd9f88ab6 in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release (this=0x555559e65440)
    at /usr/include/c++/11/bits/shared_ptr_base.h:168
#30 0x00007fffd9f867b5 in std::__shared_count<(__gnu_cxx::_Lock_policy)2>::~__shared_count (this=0x555559e68798, 
    __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:702
#31 0x00007fffddf4a334 in std::__shared_ptr<torch::jit::Graph, (__gnu_cxx::_Lock_policy)2>::~__shared_ptr (
    this=0x555559e68790, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:1149
#32 0x00007fffddf4a354 in std::shared_ptr<torch::jit::Graph>::~shared_ptr (this=0x555559e68790, 
    __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr.h:122
#33 0x00007fffdf82aa1e in torch::jit::GraphFunction::~GraphFunction (this=0x555559e68710, __in_chrg=<optimized out>)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/torch/csrc/jit/api/function_impl.h:11
#34 0x00007fffdf82aa5a in torch::jit::GraphFunction::~GraphFunction (this=0x555559e68710, __in_chrg=<optimized out>)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/torch/csrc/jit/api/function_impl.h:11
#35 0x00007fffef24c7ce in std::default_delete<torch::jit::Function>::operator() (this=0x555559e66458, 
    __ptr=0x555559e68710) at /usr/include/c++/11/bits/unique_ptr.h:85
#36 0x00007fffef23c448 in std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> >::~unique_ptr (this=0x555559e66458, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/unique_ptr.h:361
#37 0x00007fffef2ba6f3 in std::_Destroy<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> > > (__pointer=0x555559e66458) at /usr/include/c++/11/bits/stl_construct.h:140
#38 0x00007fffef29f424 in std::_Destroy_aux<false>::__destroy<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> >*> (__first=0x555559e66458, __last=0x555559e66460)
    at /usr/include/c++/11/bits/stl_construct.h:152
#39 0x00007fffef28555c in std::_Destroy<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> >*> (__first=0x555559e66450, __last=0x555559e66460) at /usr/include/c++/11/bits/stl_construct.h:185
#40 0x00007fffef26148f in std::_Destroy<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> >*, std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> > > (__first=0x555559e66450, 
    __last=0x555559e66460) at /usr/include/c++/11/bits/alloc_traits.h:746
#41 0x00007fffef2dad91 in std::vector<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> >, std::allocator<std::unique_ptr<torch::jit::Function, std::default_delete<torch::jit::Function> > > >::~vector (
    this=0x555558bede60, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/stl_vector.h:680
#42 0x00007fffef2d1d7a in torch::jit::CompilationUnit::~CompilationUnit (this=0x555558bede60, 
    __in_chrg=<optimized out>)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/torch/csrc/jit/api/compilation_unit.h:48
#43 0x00007fffef2eed48 in __gnu_cxx::new_allocator<torch::jit::CompilationUnit>::destroy<torch::jit::CompilationUnit> (
    this=0x555558bede60, __p=0x555558bede60) at /usr/include/c++/11/ext/new_allocator.h:162
#44 0x00007fffef2eebf3 in std::allocator_traits<std::allocator<torch::jit::CompilationUnit> >::destroy<torch::jit::CompilationUnit> (__a=..., __p=0x555558bede60) at /usr/include/c++/11/bits/alloc_traits.h:531
#45 0x00007fffef2ee407 in std::_Sp_counted_ptr_inplace<torch::jit::CompilationUnit, std::allocator<torch::jit::CompilationUnit>, (__gnu_cxx::_Lock_policy)2>::_M_dispose (this=0x555558bede50)
    at /usr/include/c++/11/bits/shared_ptr_base.h:528
#46 0x00007fffee809326 in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release (this=0x555558bede50)
    at /usr/include/c++/11/bits/shared_ptr_base.h:168
#47 0x00007fffee803465 in std::__shared_count<(__gnu_cxx::_Lock_policy)2>::~__shared_count (this=0x7ffb98f11310, 
    __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:702
#48 0x00007fffeed6dc84 in std::__shared_ptr<torch::jit::CompilationUnit, (__gnu_cxx::_Lock_policy)2>::~__shared_ptr (
    this=0x7ffb98f11308, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:1149
#49 0x00007fffeed6dcce in std::shared_ptr<torch::jit::CompilationUnit>::~shared_ptr (this=0x7ffb98f11308, 
    __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr.h:122
#50 0x00007fffef259523 in pybind11::class_<torch::jit::CompilationUnit, std::shared_ptr<torch::jit::CompilationUnit> >::dealloc (v_h=...)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/third_party/pybind11/include/pybind11/pybind11.h:1863
#51 0x00007fffee7fd46a in pybind11::detail::clear_instance (self=0x7ffb98f112f0)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/third_party/pybind11/include/pybind11/detail/class.h:424
#52 0x00007fffee7fd581 in pybind11::detail::pybind11_object_dealloc (self=0x7ffb98f112f0)
    at /home/haozhe/rebase/frameworks.ai.pytorch.private-cpu/third_party/pybind11/include/pybind11/detail/class.h:448
#53 0x0000555555663fac in _Py_Dealloc (op=<optimized out>)
    at /tmp/build/80754af9/python-split_1634043551344/work/Objects/object.c:2215
#54 _Py_DECREF () at /tmp/build/80754af9/python-split_1634043551344/work/Include/object.h:478

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+gitd9f822b
Is debug build: True
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.1.0-1ubuntu1~20.04) 11.1.0
Clang version: 9.0.1-12
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
Stepping: 6
CPU MHz: 2600.000
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
L1d cache: 3 MiB
L1i cache: 2 MiB
L2 cache: 80 MiB
L3 cache: 96 MiB
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==2.1.0a0+gitd9f822b
[pip3] torchvision==0.15.0a0+135a0f9
[conda] mkl 2022.0.1 h06a4308_117
[conda] mkl-include 2022.1.0 pypi_0 pypi
[conda] mkl-static 2022.1.0 pypi_0 pypi
[conda] numpy 1.21.2 py38hd8d4704_0
[conda] numpy-base 1.21.2 py38h2b8c604_0
[conda] torch 1.11.0 pypi_0 pypi
[conda] torchvision 0.15.0a0+135a0f9 dev_0

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

sanchitintel commented 1 year ago

This was originally reported in #35600. This issue's description adds debugging information beyond what the related issues contain, so I am not marking it as a duplicate. Instead, I am linking it to the original issue.

jgong5 commented 1 year ago

@sanchitintel Just FYI. I found torch.jit.freeze is holding extra memory.

sanchitintel commented 1 year ago

Hi @jgong5, can you please elaborate? Do you mean torch.jit.freeze also has a memory leak issue? Thanks!

jgong5 commented 1 year ago

What I found was that memory cannot be reclaimed with gc after calling torch.jit.freeze, but it can be reclaimed without jit.freeze.

zhuhaozhe commented 1 year ago

I found that this is not specific to torch.jit.freeze. We can also reproduce it with torch.jit.trace whenever https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/constants.cpp#L61 is invoked. Once the tensor is stored as a constant attribute in the graph (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/constants.cpp#L80), we can no longer release it.

The backtrace shows that torch::jit::CompilationUnit holds the graph, which in turn holds the tensor. However, I have no clue what is holding this torch::jit::CompilationUnit.
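The constant-baking behavior described above can be checked from Python. The following is a minimal sketch that lists the tensors torch.jit.trace stored in a graph as prim::Constant nodes. It assumes the torch._C.Node bindings kind(), hasAttribute(), kindOf() and t() behave as in current PyTorch releases. The names tensor_constants, PlainAttr and BufferAttr are hypothetical and used only for illustration.

The register_buffer contrast rests on an assumption about the tracer: a registered buffer is read through prim::GetAttr instead of being inlined into the graph. torch.jit.freeze folds module attributes into constants anyway, so this contrast does not help after freezing.

```python
import torch


def tensor_constants(graph):
    """Collect the tensors stored as prim::Constant values in a TorchScript graph."""
    found = []
    for node in graph.nodes():
        # A tensor constant keeps its payload in a "value" attribute of kind "t".
        if (
            node.kind() == "prim::Constant"
            and node.hasAttribute("value")
            and node.kindOf("value") == "t"
        ):
            found.append(node.t("value"))
    return found


class PlainAttr(torch.nn.Module):  # hypothetical: mirrors M from the report
    def __init__(self):
        super().__init__()
        self.w1 = torch.rand(4, 3)  # plain attribute: neither parameter nor buffer

    def forward(self, x):
        return self.w1 + x


class BufferAttr(torch.nn.Module):  # hypothetical: same model, w1 registered as a buffer
    def __init__(self):
        super().__init__()
        self.register_buffer("w1", torch.rand(4, 3))

    def forward(self, x):
        return self.w1 + x


example = torch.zeros(4, 1)
plain_consts = tensor_constants(torch.jit.trace(PlainAttr().eval(), example).graph)
buffer_consts = tensor_constants(torch.jit.trace(BufferAttr().eval(), example).graph)
print(len(plain_consts), len(buffer_consts))
```

Applied to the reproduction script, tensor_constants(a_trace.graph) should return the 1e7 x 10 w1 tensor. That is the allocation that stays alive until the process exits.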

zhuhaozhe commented 1 year ago

Hi @eellison, I saw your PR https://github.com/pytorch/pytorch/pull/65442, which fixed a circular reference between

Object -> CompilationUnit and CompilationUnit -> Graph (which owns the Constant Object)

Do you have any suggestions on this issue? Is there a circular reference between Tensors, Graph, and CompilationUnit?

chengzeyi commented 1 year ago

I have figured out why this happens. When you call torch.jit.freeze, the original module gets cloned in Module::clone_impl, which recursively calls Module::clone_method for each method of the module and its submodules. Module::clone_method then calls CompilationUnit::create_function for each cloned GraphFunction object and registers it persistently in the CompilationUnit. That is why the constant tensors are held by the CompilationUnit.

If you want to release the function registrations in the CompilationUnit, you can do the following, and the memory will be released:

torch.jit._state._python_cu.drop_all_functions()

jgong5 commented 1 year ago

Thanks for the info. This function is also called in some PyTorch tests to avoid memory-leak reports. It has to be used with care, though, because it is global. Also, after calling it you would have to recreate the eager-mode model before tracing again.
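The workaround and its caveats can be wrapped in a small helper. The following is a minimal sketch that assumes the private torch.jit._state._python_cu.drop_all_functions() binding quoted in this thread exists in your build. The helper returns False if it does not. The names drop_global_jit_functions and M are hypothetical.

The global CompilationUnit is shared by every traced and scripted module in the process. For that reason, the sketch deletes all TorchScript modules before dropping, and any later trace should start from a freshly constructed eager-mode model.

```python
import gc

import torch


def drop_global_jit_functions():
    """Drop every function registered in TorchScript's global Python CompilationUnit.

    torch.jit._state._python_cu is private, process-wide state: dropping its
    functions affects every traced or scripted module in the process, so
    delete all of them first. Returns False when this build does not expose
    drop_all_functions.
    """
    gc.collect()  # make sure module objects caught in reference cycles are gone
    cu = torch.jit._state._python_cu
    drop = getattr(cu, "drop_all_functions", None)
    if drop is None:
        return False
    drop()  # releases the registered graphs and the constant tensors they hold
    return True


class M(torch.nn.Module):  # hypothetical small version of the reproducer
    def __init__(self):
        super().__init__()
        self.w1 = torch.rand(1000, 10)  # plain attribute, traced as a constant

    def forward(self, x):
        return self.w1 + x


example = torch.zeros(1000, 1)
traced = torch.jit.trace(M().eval(), example)
frozen = torch.jit.freeze(traced)
out = frozen(example)

# Release every Python reference to TorchScript modules before dropping.
del traced, frozen
dropped = drop_global_jit_functions()
```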

chengzeyi commented 1 year ago

Yes. It is crucial global state, and many implementation details of torch.jit depend on a global singleton torch._C.CompilationUnit instance. A possible workaround I can think of is a custom JIT pass that manually frees all nodes of the module's graph instances once you are sure you will no longer need them. Some resources will still be left in the registration, but they do not matter in most cases.

jgong5 commented 1 year ago

Sounds like a good idea. Would you like to submit a PR for it?

chengzeyi commented 1 year ago

I have worked out a proof of concept for this, and it works. Rather than freeing all nodes of the graphs, I register a callback that cleans up all of a module's registered graphs when the module is destructed. The fix is implemented as a C++ extension and does not require any change to the PyTorch source. I think it would be better integrated into PyTorch itself. However, I am currently busy with my private AIGC inference acceleration framework, and the fix is a little hacky, so I will demonstrate it later.

samlcharreyron commented 7 months ago

@chengzeyi would you be able to share your solution for this?

jiqing-feng commented 6 months ago

Hi @chengzeyi, I ran into this issue too. Thanks for clarifying it. Do you have any plans to upstream a fix to PyTorch? If so, I would like to help :)

chengzeyi commented 6 months ago

@samlcharreyron @jiqing-feng Anyone who is interested in solving this can take a look at these:

https://github.com/chengzeyi/stable-fast/blob/d62fc58db0a450853c7da718a7d87f78aeb58a6b/src/sfast/jit/utils.py#L11
https://github.com/chengzeyi/stable-fast/blob/d62fc58db0a450853c7da718a7d87f78aeb58a6b/src/sfast/csrc/jit/compilation_unit.cpp#L12