pyutils / line_profiler

Line-by-line profiling for Python

PyTorch memory profiling #237

Open tmm1 opened 1 year ago

tmm1 commented 1 year ago

I have been exploring the VRAM usage of PyTorch code and wanted to share my experience.

I searched for previous discussions related to memory profiling and customized stats collection, and found #188 and #216.

At first I experimented a bit with pytorch_memlab's LineProfiler. However, that tool is geared more towards peak memory usage than towards identifying which line is responsible for memory growth.
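
To make the peak-versus-growth distinction concrete, here is a small sketch that uses Python's standard `tracemalloc` module as a stand-in for the CUDA allocator (an assumption for illustration; neither tool discussed here uses it). The helper `peak_vs_net` is hypothetical. A function that allocates a large temporary and frees it shows a high peak but almost no net growth, so a peak-oriented view and a per-line growth view tell different stories about it.

```python
import tracemalloc

def peak_vs_net(fn):
    """Hypothetical helper: run fn() and return (net change, peak above start)
    in bytes of Python heap memory traced by tracemalloc."""
    tracemalloc.start()
    try:
        tracemalloc.reset_peak()  # requires Python 3.9+
        start, _ = tracemalloc.get_traced_memory()
        fn()
        end, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return end - start, peak - start

def temporary_spike():
    # A large temporary buffer that is released before the function returns.
    scratch = bytearray(10_000_000)
    del scratch

net, peak = peak_vs_net(temporary_spike)
print(f"net change: {net:,} bytes, peak above start: {peak:,} bytes")
```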

So I started exploring how line_profiler's machinery could be extended or reused instead. It turns out to be pretty simple:

diff --git a/line_profiler/_line_profiler.pyx b/line_profiler/_line_profiler.pyx
index c9c8f32..b698111 100644
--- a/line_profiler/_line_profiler.pyx
+++ b/line_profiler/_line_profiler.pyx
@@ -5,6 +5,7 @@ This is the Cython backend used in :py:mod:`line_profiler.line_profiler`.
 from .python25 cimport PyFrameObject, PyObject, PyStringObject
 from sys import byteorder
 import sys
+import torch
 cimport cython
 from cpython.version cimport PY_VERSION_HEX
 from libc.stdint cimport int64_t
@@ -79,9 +80,13 @@ cdef extern from "Python.h":
     cdef int PyTrace_C_RETURN

 cdef extern from "timers.c":
-    PY_LONG_LONG hpTimer()
+    #PY_LONG_LONG hpTimer()
     double hpTimerUnit()

+def hpTimer():
+    return torch.cuda.memory_allocated(0)
+    #return torch.cuda.memory_reserved(0)
 cdef extern from "unset_trace.c":
     void unset_trace()

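For anyone who wants to try the same trick without rebuilding the Cython extension, here is a minimal pure-Python sketch of it: a `sys.settrace` line tracer that charges the change in an arbitrary counter to whichever line was executing. `profile_lines` and `workload` are hypothetical names, not part of line_profiler, and `tracemalloc` stands in for `torch.cuda.memory_allocated(0)` so the example runs without a GPU.

```python
import sys
import tracemalloc
from collections import defaultdict

def profile_lines(func, counter, *args, **kwargs):
    """Hypothetical helper (not line_profiler's API): run func and charge the
    change in counter() between consecutive line events to the line that was
    executing. counter is any zero-argument callable returning a number, e.g.
    lambda: torch.cuda.memory_allocated(0) on a CUDA machine (untested here)."""
    target = func.__code__
    deltas = defaultdict(int)
    last = {"line": None, "value": 0}

    def local_trace(frame, event, arg):
        if event in ("line", "return"):
            now = counter()
            if last["line"] is not None:
                # Everything that changed since the previous event belongs to
                # the line that was running in between.
                deltas[last["line"]] += now - last["value"]
            last["line"] = frame.f_lineno if event == "line" else None
            last["value"] = now
        return local_trace

    def global_trace(frame, event, arg):
        # Install the line tracer only for frames running func itself.
        return local_trace if frame.f_code is target else None

    previous = sys.gettrace()
    sys.settrace(global_trace)
    try:
        result = func(*args, **kwargs)
    finally:
        sys.settrace(previous)
    return result, dict(deltas)


def workload():
    keep = bytearray(4_000_000)  # retained until the function returns
    tmp = bytearray(8_000_000)   # allocated here...
    del tmp                      # ...and released here
    return len(keep)


tracemalloc.start()
try:
    result, deltas = profile_lines(workload, lambda: tracemalloc.get_traced_memory()[0])
finally:
    tracemalloc.stop()

first = workload.__code__.co_firstlineno
for lineno in sorted(deltas):
    # Print offsets relative to the def line, with the byte delta per line.
    print(f"line +{lineno - first}: {deltas[lineno]:+,d} bytes")
```

As with the Cython patch, whatever changes between two line events is attributed to the earlier line, so allocations made by functions called from that line are included, while memory freed when the frame's locals are torn down after `return` is not charged to any line.
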
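The results show where memory is allocated and released. The third column, which normally holds time, now holds the net change in allocated bytes attributed to each line, so the percentage column is relative to the function's net allocation and can exceed 100% or go negative: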

  1837        24          0.0      0.0      0.0                  with self.accelerator.accumulate(model):
  1838        12 4011721216.0    3e+08   2260.0                      tr_loss_step = self.training_step(model, inputs)
  1840        48     -12288.0   -256.0     -0.0                  if (
  1841        12          0.0      0.0      0.0                      args.logging_nan_inf_filter
  1842        12          0.0      0.0      0.0                      and not is_torch_tpu_available()
  1843        24      12288.0    512.0      0.0                      and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
  1844                                                           ):
  1845                                                               # if loss is nan or inf simply add the average of previous logged losses
  1846                                                               tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
  1847                                                           else:
  1848        12          0.0      0.0      0.0                      tr_loss += tr_loss_step
  1917        12  160466944.0    1e+07     90.4                          self.optimizer.step()
  1918        12          0.0      0.0      0.0                          optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
  1920        12          0.0      0.0      0.0                      if optimizer_was_run:
  1921                                                                   # Delay optimizer scheduling until metrics are generated
  1922        12          0.0      0.0      0.0                          if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
  1923        12          0.0      0.0      0.0                              self.lr_scheduler.step()
  1925        12       -4e+09   -3e+08  -2250.4                      model.zero_grad()
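
The large percentages follow from how line_profiler computes the percentage column: each line's total divided by the sum over all lines of the function, times 100. With signed byte deltas that sum is the function's net allocation, which can be far smaller than any single line. As a rough check, the sketch below solves for the net total implied by two rows of the excerpt; the percentages are rounded to one decimal, and the excerpt omits some lines that still count toward the total.

```python
# Values copied from the excerpt above: (bytes charged to the line, percent column).
rows = {
    1838: (4011721216.0, 2260.0),  # tr_loss_step = self.training_step(model, inputs)
    1917: (160466944.0, 90.4),     # self.optimizer.step()
}

# percent = 100 * line_total / function_total, so function_total = 100 * line_total / percent.
implied_totals = {line: 100 * total / pct for line, (total, pct) in rows.items()}
for line, t in implied_totals.items():
    print(f"line {line}: implied net change for the function ~ {t / 2**20:.1f} MiB")
```

Both rows point to the same net growth of roughly 177.5 MB over the profiled calls, which is consistent with the percentage column being relative to the net change rather than to the largest allocation.
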

Erotemic commented 1 year ago

Interesting. Thank you for taking the time to share your knowledge. I'm a heavy torch user, so this would be useful for me. I don't think this repo is the place to implement it, but a fork of this repo certainly is.

I'm wondering if there is a C-level mechanism that could be used in hpTimer. It looks like the torch.cuda.memory_allocated call boils down to a call to torch._C._cuda_memoryStats, which is defined here:

Probably a way to hook into torchlib to get that info more efficiently.
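
At the Python level, `torch.cuda.memory_allocated(device)` returns the `"allocated_bytes.all.current"` entry of `torch.cuda.memory_stats(device)`, which flattens the nested dictionary produced by `torch._C._cuda_memoryStats`. Each call therefore builds a full statistics dictionary, presumably why a lower-level hook would be cheaper. The sketch below stays at the Python level: `make_memory_counter` is a hypothetical helper that reads that same entry when CUDA is available and otherwise falls back to `tracemalloc` (an assumption made only so it runs on CPU-only machines).

```python
import tracemalloc

def make_memory_counter(device=0):
    """Hypothetical helper: return (name, zero-argument callable giving current allocated bytes)."""
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        def cuda_allocated():
            # Same figure torch.cuda.memory_allocated(device) reports; returns 0
            # (via the .get default) if CUDA has not been initialized yet.
            return torch.cuda.memory_stats(device).get("allocated_bytes.all.current", 0)
        return "cuda", cuda_allocated
    # Fallback without a GPU: Python heap bytes traced by tracemalloc (left running).
    if not tracemalloc.is_tracing():
        tracemalloc.start()
    return "tracemalloc", lambda: tracemalloc.get_traced_memory()[0]

name, counter = make_memory_counter()
before = counter()
buf = bytearray(1_000_000)  # host memory: visible to tracemalloc, not to the CUDA counter
after = counter()
print(name, after - before)
```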

tmm1 commented 1 year ago

> Probably a way to hook into torchlib to get that info more efficiently.

I got this working:

EDIT: Notebook showing example usage: