Open ZekunZh opened 1 year ago
cc @awaelchli @four4fish @carmocca
Thanks for investigating this and providing an example code @ZekunZh!
I'm not sure we can remove `self.lightning_module.cpu()`; we would need to investigate the implications. This could be an unexpected breaking change for users. We should check whether there is a different solution to this first.
I ran a couple of tests with your script, removing Lightning and only running with the raw PyTorch model:
```python
...
for i in range(N_ITERATIONS):
    torch.cuda.empty_cache()
    gc.collect()

    model = model.to("cuda:0")
    with torch.inference_mode():
        for batch in dataloader:
            model(batch.to("cuda:0"))
    # model.cpu()

    current_memory = process.memory_info().rss
    memory_usage = convert_bytes_to_megabytes(current_memory - initial_memory)
    print(f"Iteration {i + 1}: Resident Memory used: {memory_usage:.3f} MB")
    memory_usages.append(memory_usage)
...
```
(results produced with torch nightly 2.2.0.dev20230920+cu121)
While the memory increase is definitely smaller, it still climbs steadily. I suppose on a production system with thousands of requests these few MB could add up. I'm definitely not familiar with memory management in Python and PyTorch, but there seems to be some hidden state somewhere that's not just in Lightning. Perhaps the impact is just amplified with Lightning and the root cause lies elsewhere.
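One way to dig further (a hypothetical diagnostic, not something that was run for the numbers above) is to diff `tracemalloc` snapshots between iterations and see which Python-level allocations survive each loop. Note that this only covers Python allocations, not the CUDA caching allocator or glibc arenas, which can also contribute to RSS growth:

```python
# Hypothetical diagnostic sketch: compare tracemalloc snapshots between
# iterations to see which Python allocations pile up across loops.
import gc
import tracemalloc

tracemalloc.start()
previous = tracemalloc.take_snapshot()

for i in range(5):
    # ... run one inference iteration here ...
    gc.collect()
    current = tracemalloc.take_snapshot()
    print(f"--- iteration {i + 1}: top surviving allocations ---")
    for stat in current.compare_to(previous, "lineno")[:10]:
        print(stat)
    previous = current
```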
Hey, I have the same problem, but during training. I am currently using Lightning in an active learning loop, in which I recreate the trainer in each iteration. Calling `gc.collect()` after each iteration successfully fixes this issue. I am using lightning version 2.2.0.post0 and torch 2.2.0.
```python
import gc
import os
from copy import deepcopy
from typing import Any

import psutil
import torch
import torch.nn.functional as F
from lightning import LightningModule, Trainer
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm


class MLP(LightningModule):
    def __init__(
        self,
        n_hidden=10240  # Quite a large value to amplify the effect, my actual model has roughly the same size
    ):
        super().__init__()
        torch.manual_seed(0)
        self.model = nn.Sequential(
            nn.Linear(1024, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 2)
        )
        self._initial_state = deepcopy(self.state_dict())

    def forward(self, x) -> Any:
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch):
        x, y = batch
        pred = self.forward(x)
        loss = F.nll_loss(pred, y.view(-1))
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters())

    def reset(self):
        self.load_state_dict(self._initial_state)


def convert_bytes_to_megabytes(memory_bytes):
    return memory_bytes / 1024 ** 2


def run_pytorch(model, dataloader, iterations=10):
    model.train()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []
    for _ in range(iterations):
        model.reset()
        optimizer = optim.Adam(model.parameters())
        for x, y in tqdm(dataloader):
            x = x.to("cuda")
            y = y.to("cuda")
            pred = model(x)
            loss = F.nll_loss(pred, y)
            # loss = model.training_step((x, y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)
    return memory_usages


def run_lightning(model, dataloader, iterations=10, fix=False):
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []
    for _ in range(iterations):
        model.reset()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False)
        trainer.fit(model, train_dataloaders=dataloader)
        if fix:
            gc.collect()
        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)
    return memory_usages


def main():
    # Create some example data
    torch.manual_seed(0)
    X = torch.randn(size=(64, 1024))
    Y = (torch.rand(size=(64,)) < 0.2).long()
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        num_workers=1
    )
    model = MLP()
    # Remove model from cpu memory
    model.to("cuda")
    torch_usage = run_pytorch(model, dataloader)
    lightning_usage = run_lightning(model, dataloader)
    lightning_usage_fix = run_lightning(model, dataloader, fix=True)
    plt.plot(torch_usage, label="PyTorch")
    plt.plot(lightning_usage, label="Lightning")
    plt.plot(lightning_usage_fix, label="Lightning - Fix")
    plt.legend()
    plt.xlabel("Iterations")
    plt.ylabel("Memory usage")
    plt.show()


main()
```
PS: I think the measured memory for the fixed version is negative because the trainer moves some things from CPU memory directly to the GPU, so RSS ends up below the initial baseline.
Thanks for collecting more data here. So if `gc.collect()` "fixes" this, it must mean that there is nothing seriously wrong with the code in Lightning/PyTorch, because the references have been released; it's just Python not collecting the garbage fast enough. Is that right?
Hypothetically, if we were to insert a `gc.collect()` at the beginning of `Trainer.fit()` (cleaning up memory in case a trainer instance was deleted before), would this be equivalent to your "fix"?
Yes, it seems that the result is the same. Using the following implementation with `fix=True` solves the issue.
```python
class GCTrainer(Trainer):
    def fit(self, *args, fix: bool = False, **kwargs):
        # `fix` is keyword-only so it cannot swallow the positional model argument
        if fix:
            gc.collect()
        super().fit(*args, **kwargs)
```
@carmocca What are your thoughts on adding a `gc.collect()` call at the beginning of the Trainer's `_run()` function?
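Something along these lines, expressed as a monkeypatch that users could already apply themselves today (a sketch only; it assumes `Trainer._run` keeps its current private name and signature, and it is not actual Lightning source):

```python
# Hypothetical monkeypatch sketch of the proposal: start every Trainer run
# with a full garbage-collection pass to clean up after any previously
# discarded Trainer instance.
import gc

from lightning import Trainer

_original_run = Trainer._run


def _run_with_gc(self, *args, **kwargs):
    gc.collect()  # free objects left behind by a previously deleted Trainer
    return _original_run(self, *args, **kwargs)


Trainer._run = _run_with_gc
```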
I'm leaning towards not adding it. Instantiating trainers like this in a loop is very unconventional, and there is a cost to triggering `gc` for everybody else. We also don't understand why these objects are not getting freed periodically as you'd expect. Perhaps this is Python-version-dependent or platform-dependent.
If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this `collect()` call.
> If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this collect() call
@carmocca Just to clarify. Above we've determined that the Trainer releases these objects, so their refcount is actually 0. It's just that the GC does not collect them from memory quickly enough. The fact that adding `gc.collect()` makes the memory drop means the references were already released, so there isn't any fix we could possibly do there. The GC is making the decisions here.
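For intuition, here is a standalone snippet (nothing to do with Lightning internals, purely illustrative) of the classic case where memory only comes back once the collector actually runs: objects that are unreachable but sit in a reference cycle are not freed by reference counting and wait for a `gc` pass:

```python
# Standalone illustration, unrelated to Lightning: unreachable objects that are
# part of a reference cycle are reclaimed only when the cyclic garbage
# collector runs (or gc.collect() is called explicitly).
import gc


class Node:
    def __init__(self):
        self.payload = bytearray(50 * 1024 * 1024)  # ~50 MB to make the effect visible
        self.other = None


def make_cycle():
    a, b = Node(), Node()
    a.other, b.other = b, a  # a <-> b reference cycle
    # both nodes become unreachable when this function returns


gc.disable()  # pretend the automatic collector hasn't run yet
make_cycle()
print("before collect:", sum(isinstance(o, Node) for o in gc.get_objects()), "Node objects alive")
print("collected:", gc.collect(), "objects")
print("after collect:", sum(isinstance(o, Node) for o in gc.get_objects()), "Node objects alive")
gc.enable()
```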
> Instantiating trainers like this in a loop is very unconventional

I agree. In light of this I am also ok closing this issue. But by the same argument, I am also ok with adding the `gc.collect()`. For users who loop over Trainers like this, there is already overhead in the setup and teardown of the trainer alone, so a `gc.collect()` shouldn't be noticeable IMO.
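For a rough sense of scale, the cost of a full collection pass can be measured in isolation (a hypothetical micro-benchmark, not something measured in this thread; the time grows with the number of live tracked objects in the process):

```python
# Hypothetical micro-benchmark of a full gc pass; real numbers depend entirely
# on how many tracked objects are alive in the workload.
import gc
import timeit

per_call_ms = timeit.timeit(gc.collect, number=100) / 100 * 1e3
print(f"gc.collect(): {per_call_ms:.2f} ms per call")
```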
Bug description
The memory consumption (RSS memory) continues to grow when `Trainer` is instantiated multiple times during inference.

In our production environment, we currently need to instantiate a `Trainer` for each request, which contains 1 image. That's why we observed the OOM issue.

We understand that it might not be the best practice to use Lightning in production; any suggestions / comments are welcome! 😃

The following curve can be reproduced with the provided python script, running 1000 iterations.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
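(The original reproduction script is not included in this copy of the issue; the following is a minimal sketch of the pattern described above, a short-lived `Trainer` per request, with a placeholder model and data rather than the author's actual code.)

```python
# Hypothetical minimal sketch of the reported pattern: one Trainer per "request".
# The model, data, and iteration count are placeholders, not the original script.
import os

import psutil
import torch
from lightning import LightningModule, Trainer
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(128, 2)

    def predict_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x)


model = TinyModel()
dataloader = DataLoader(TensorDataset(torch.randn(8, 128)), batch_size=8)
process = psutil.Process(os.getpid())

for i in range(1000):
    trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
    trainer.predict(model, dataloaders=dataloader)
    # gc.collect()  # the workaround discussed in this thread
    print(f"Iteration {i + 1}: RSS = {process.memory_info().rss / 1024 ** 2:.1f} MB")
```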
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU: Tesla T4
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.0.0
  - lightning-cloud: 0.5.36
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.3
  - torch: 2.0.0
  - torchinfo: 1.5.3
  - torchmetrics: 0.11.4
  - torchvision: 0.15.1
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.11.5
  - release: 5.15.0-1042-gcp
  - version: #50~20.04.1-Ubuntu SMP Mon Sep 11 03:30:57 UTC 2023

More info
The temporary solution to fix this issue is to add `gc.collect()` at the end of the `teardown` method, while commenting out `self.lightning_module.cpu()` (a call-site variant is sketched at the end of this section).

Things that I've tried:

- Only comment out `self.lightning_module.cpu()` -> does not work 🛑
- Only comment out `_optimizers_to_device(self.optimizers, torch.device("cpu"))` -> does not work 🛑
- Comment out both the module-to-CPU and the optimizer-to-CPU calls -> does not work 🛑
- Only add `gc.collect()` -> partially works 🟡
- Comment out `_optimizers_to_device(self.optimizers, torch.device("cpu"))` + add `gc.collect()` -> partially works 🟡
- Comment out `self.lightning_module.cpu()` + add `gc.collect()` -> works better 🟢
- Comment out `self.lightning_module.cpu()` and `_optimizers_to_device(self.optimizers, torch.device("cpu"))` + add `gc.collect()` -> similar to the previous one 🟢

cc @borda
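For reference, a call-site variant of the same workaround that avoids touching Lightning internals (a sketch; the helper name and Trainer flags are placeholders, not code from this issue):

```python
# Sketch of the call-site variant of the workaround: force a full garbage
# collection after each short-lived Trainer instead of patching teardown.
import gc

from lightning import Trainer


def predict_one_request(model, dataloader):
    """Serve a single request with a throwaway Trainer (hypothetical helper)."""
    trainer = Trainer(logger=False, enable_checkpointing=False, enable_progress_bar=False)
    predictions = trainer.predict(model, dataloaders=dataloader)
    del trainer
    gc.collect()  # reclaim whatever the discarded Trainer left behind
    return predictions
```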