Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Investigate Resident Memory Increase during Inference #18640

Open ZekunZh opened 1 year ago

ZekunZh commented 1 year ago

Bug description

The memory consumption (RSS) continues to grow when the Trainer is instantiated multiple times during inference.

In our production environment, we currently need to instantiate a Trainer for each request, and each request contains a single image. That's why we observed the OOM issue.

We understand that it might not be best practice to use Lightning in production; any suggestions / comments are welcome! 😃

The following curve can be reproduced with the provided Python script, running 1000 iterations.

[Figure: 2023-09-26T14h23m39s_memory_usage_originalStrategy — resident memory usage over 1000 iterations with the default strategy]

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import gc
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.pytorch import Trainer
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset

TIME_FORMAT = "%Y-%m-%dT%Hh%Mm%Ss"

def get_time() -> str:
    """get current time and convert to specific format"""
    return datetime.utcnow().strftime(TIME_FORMAT)

class SimpleDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.randn((1, 28, 28))

class SimpleModel(LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(28 * 28, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer3 = nn.Linear(512, 512)
        self.layer4 = nn.Linear(512, 512)
        self.layer5 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        x = self.layer5(x)
        return x

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

class SingleDeviceProdStrategy(SingleDeviceStrategy):
    def teardown(self) -> None:
        _optimizers_to_device(self.optimizers, torch.device("cpu"))
        if self.lightning_module is not None:
            self.lightning_module.cpu()
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()
        gc.collect()

def convert_bytes_to_megabytes(memory_bytes):
    return memory_bytes / 1024 ** 2

def run_inference_and_monitor_memory(tag: str):
    dataset = SimpleDataset()
    dataloader = DataLoader(dataset, batch_size=32)
    model = SimpleModel()

    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss

    memory_usages = []

    N_ITERATIONS = 1000

    for i in range(N_ITERATIONS):
        strategy = SingleDeviceStrategy(device=torch.device("cuda:0"))
        # strategy = SingleDeviceProdStrategy(device=torch.device("cuda:0"))
        trainer = Trainer(strategy=strategy)
        trainer.predict(model, dataloader)
        current_memory = process.memory_info().rss
        memory_usage = convert_bytes_to_megabytes(current_memory - initial_memory)
        print(f"Iteration {i + 1}: Resident Memory used: {memory_usage:.3f} MB")
        memory_usages.append(memory_usage)

    plt.plot(range(1, N_ITERATIONS+1), memory_usages)
    plt.xlabel('Iteration')
    plt.ylabel('Resident Memory used (MB)')
    plt.title('Resident Memory Usage over Iterations')

    # Specify the y-ticks
    min_memory = min(memory_usages)
    max_memory = max(memory_usages)
    yticks = np.linspace(min_memory, max_memory, num=20)  # Increase num to increase density
    plt.yticks(yticks)

    fig_path = Path(__file__).parent / 'oom_minimal_example' / f'{get_time()}_memory_usage_{tag}.png'
    fig_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(fig_path)
    print(f"Saved figure to {fig_path.resolve()}")

def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag", type=str, required=True)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    run_inference_and_monitor_memory(tag=args.tag)

Error messages and logs


Environment

Current environment * CUDA: - GPU: - Tesla T4 - available: True - version: 11.7 * Lightning: - lightning: 2.0.0 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.3 - torch: 2.0.0 - torchinfo: 1.5.3 - torchmetrics: 0.11.4 - torchvision: 0.15.1 * Packages: - absl-py: 1.4.0 - adal: 1.2.7 - addict: 2.4.0 - aiofiles: 23.1.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aiosignal: 1.3.1 - albumentations: 1.1.0 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.0 - appdirs: 1.4.4 - argcomplete: 2.1.2 - arrow: 1.2.3 - async-timeout: 4.0.2 - asyncssh: 2.13.1 - atpublic: 4.0 - attrs: 23.1.0 - autoflake: 2.2.0 - azure-common: 1.1.28 - azure-core: 1.27.0 - azure-graphrbac: 0.61.1 - azure-mgmt-authorization: 3.0.0 - azure-mgmt-containerregistry: 10.1.0 - azure-mgmt-core: 1.4.0 - azure-mgmt-keyvault: 10.2.2 - azure-mgmt-resource: 22.0.0 - azure-mgmt-storage: 21.0.0 - azure-nspkg: 3.0.2 - azure-storage: 0.36.0 - azure-storage-blob: 1.1.0 - azure-storage-common: 1.1.0 - azure-storage-nspkg: 3.1.0 - azureml-core: 1.50.0 - backports.tempfile: 1.0 - backports.weakref: 1.0.post1 - bcrypt: 4.0.1 - beautifulsoup4: 4.12.2 - billiard: 3.6.4.0 - black: 23.3.0 - blessed: 1.20.0 - blindspin: 2.0.1 - boto3: 1.26.149 - botocore: 1.29.149 - cachetools: 5.3.1 - celery: 5.2.2 - certifi: 2023.5.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.1.0 - chumpy: 0.71 - clearml: 1.3.2 - click: 8.0.2 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - clickclick: 20.10.2 - cloudpickle: 2.2.1 - cmake: 3.26.4 - colorama: 0.4.6 - configobj: 5.0.8 - connexion: 2.14.2 - contextlib2: 0.5.5 - contourpy: 1.0.7 - coverage: 7.2.5 - crayons: 0.4.0 - croniter: 1.3.15 - cryptography: 3.4.8 - cycler: 0.11.0 - cython: 0.29.33 - dacite: 1.7.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - deprecated: 1.2.14 - detectron2: 0.7+cu118 - detrex: 0.3.0 - dictdiffer: 0.9.0 - dill: 0.3.6 - diskcache: 5.6.1 - distlib: 0.3.6 - distro: 1.8.0 - dnspython: 2.3.0 - docker: 6.1.3 - docker-pycreds: 0.4.0 - dpath: 2.1.6 - dulwich: 0.21.5 - dvc: 2.46.0 - dvc-data: 0.42.3 - dvc-gs: 2.22.0 - dvc-http: 2.30.2 - dvc-objects: 0.22.0 - dvc-render: 0.5.3 - dvc-studio-client: 0.10.0 - dvc-task: 0.2.1 - einops: 0.6.1 - et-xmlfile: 1.1.0 - eventlet: 0.33.3 - fairscale: 0.4.13 - fastapi: 0.86.0 - fiftyone: 0.20.0 - fiftyone-brain: 0.11.0 - fiftyone-db: 0.4.0 - filelock: 3.12.0 - flake8: 6.0.0 - flask: 2.2.5 - flask-testing: 0.8.1 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.39.4 - frozenlist: 1.3.3 - fsspec: 2023.5.0 - ftfy: 6.1.1 - funcy: 2.0 - furl: 2.1.3 - future: 0.18.3 - fvcore: 0.1.5.post20220506 - gcsfs: 2023.5.0 - gitdb: 4.0.10 - gitdb2: 2.0.6 - gitpython: 3.1.31 - glmlib: 1.0.0 - glob2: 0.7 - google-api-core: 1.34.0 - google-auth: 2.19.1 - google-auth-oauthlib: 1.0.0 - google-cloud-core: 2.3.2 - google-cloud-pubsub: 1.0.2 - google-cloud-storage: 1.43.0 - google-crc32c: 1.5.0 - google-resumable-media: 1.3.0 - googleapis-common-protos: 1.59.0 - gputil: 1.4.0 - grandalf: 0.8 - graphql-core: 3.2.3 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.54.2 - grpcio-status: 1.48.2 - h11: 0.14.0 - h2: 4.1.0 - hpack: 4.0.0 - httpcore: 0.17.2 - httpx: 0.24.1 - huggingface-hub: 0.15.1 - humanfriendly: 10.0 - hydra-core: 1.3.2 - hydra-zen: 0.10.0 - hypercorn: 0.14.3 - hyperframe: 6.0.1 - identify: 2.5.24 - idna: 3.4 - imageio: 2.31.0 - imgaug: 0.4.0 - inflection: 0.5.1 - iniconfig: 2.0.0 - inquirer: 3.1.3 - iopath: 0.1.9 - isodate: 0.6.1 - isort: 5.12.0 - iterative-telemetry: 0.0.8 - 
itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jeepney: 0.8.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - json-tricks: 3.17.0 - jsonpickle: 3.0.1 - jsonschema: 4.10.0 - kaleido: 0.2.1 - keyring: 24.2.0 - keyrings.google-artifactregistry-auth: 1.1.2 - kili: 2.120.0 - kiwisolver: 1.4.4 - knack: 0.10.1 - kombu: 5.3.0 - lazy-loader: 0.2 - lightning: 2.0.0 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - lit: 16.0.5.post0 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - mccabe: 0.7.0 - mdurl: 0.1.2 - mmcv: 1.4.2 - mmpose: 0.21.0 - monai: 0.9.1 - mongoengine: 0.24.2 - more-itertools: 8.8.0 - motor: 3.1.2 - mpmath: 1.3.0 - msal: 1.22.0 - msal-extensions: 1.0.0 - msrest: 0.7.1 - msrestazure: 0.6.4 - multidict: 6.0.4 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nanotime: 0.5.2 - ndg-httpsclient: 0.5.1 - ndjson: 0.3.1 - networkx: 3.1 - nibabel: 3.2.1 - nodeenv: 1.8.0 - numpy: 1.24.2 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-cupti-cu11: 11.7.101 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.2.10.91 - nvidia-cusolver-cu11: 11.4.0.1 - nvidia-cusparse-cu11: 11.7.4.91 - nvidia-nccl-cu11: 2.14.3 - nvidia-nvtx-cu11: 11.7.91 - oauthlib: 3.2.2 - omegaconf: 2.2.1 - opencv-python: 4.7.0.72 - opencv-python-headless: 4.7.0.72 - openpyxl: 3.0.7 - ordered-set: 4.1.0 - orderedmultidict: 1.0.1 - orjson: 3.9.0 - packaging: 23.0 - pandas: 2.0.2 - paramiko: 3.2.0 - pathlib2: 2.3.7.post1 - pathspec: 0.11.1 - pathtools: 0.1.2 - patool: 1.12 - pika: 1.1.0 - pillow: 9.5.0 - pip: 23.2.1 - pkginfo: 1.9.6 - platformdirs: 3.5.1 - plotly: 5.14.1 - pluggy: 1.0.0 - portalocker: 2.7.0 - pprintpp: 0.4.0 - pre-commit: 3.2.2 - priority: 2.0.0 - prompt-toolkit: 3.0.38 - protobuf: 3.20.3 - psutil: 5.9.5 - pyaescrypt: 0.4.3 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pybind11: 2.11.1 - pycocotools: 2.0.6 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 1.10.9 - pydicom: 2.0.0 - pydot: 1.4.2 - pyelftools: 0.27 - pyflakes: 3.0.1 - pygit2: 1.12.1 - pygments: 2.15.1 - pygtrie: 2.5.0 - pyjwt: 2.1.0 - pymongo: 4.3.3 - pympler: 1.0.1 - pynacl: 1.5.0 - pyopenssl: 21.0.0 - pyparsing: 3.0.9 - pyrsistent: 0.19.3 - pysocks: 1.7.1 - pytest: 7.2.2 - pytest-mock: 3.10.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-gdcm: 3.0.21 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.3 - pytz: 2023.3 - pywavelets: 1.4.1 - pyyaml: 6.0 - qudida: 0.0.4 - readchar: 4.0.5 - regex: 2023.6.3 - requests: 2.30.0 - requests-oauthlib: 1.3.1 - retrying: 1.3.4 - rich: 13.4.1 - rsa: 4.9 - ruamel.yaml: 0.17.21 - ruff: 0.0.270 - s3transfer: 0.6.1 - schema: 0.7.0 - scikit-image: 0.20.0 - scikit-learn: 1.2.2 - scipy: 1.10.1 - scmrepo: 0.2.1 - secretstorage: 3.3.3 - sentry-sdk: 1.25.1 - setproctitle: 1.3.2 - setuptools: 67.2.0 - shapely: 2.0.1 - shortuuid: 1.0.11 - shtab: 1.6.1 - six: 1.16.0 - smmap: 5.0.0 - smmap2: 3.0.1 - sniffio: 1.3.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4.1 - sqltrie: 0.4.0 - sse-starlette: 0.10.3 - sseclient-py: 1.7.2 - starlette: 0.20.4 - starsessions: 1.3.0 - strawberry-graphql: 0.138.1 - submitit: 1.4.5 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.0 - termcolor: 2.3.0 - testcontainers: 3.0.0 - threadpoolctl: 3.1.0 - tifffile: 2023.4.12 - timm: 0.6.13 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - torch: 2.0.0 - torchinfo: 1.5.3 - torchmetrics: 0.11.4 - torchvision: 0.15.1 - tqdm: 4.64.0 - traitlets: 5.9.0 - 
triton: 2.0.0 - typeguard: 4.0.0 - typing-extensions: 4.6.3 - tzdata: 2023.3 - tzlocal: 5.0.1 - universal-analytics-python3: 1.1.1 - urllib3: 1.26.16 - uvicorn: 0.22.0 - vine: 5.0.0 - virtualenv: 20.23.0 - voluptuous: 0.13.1 - voxel51-eta: 0.8.4 - wandb: 0.15.0 - wcwidth: 0.2.6 - websocket-client: 1.5.2 - websockets: 11.0.3 - werkzeug: 2.2.3 - wheel: 0.40.0 - wrapt: 1.15.0 - wsproto: 1.2.0 - xmltodict: 0.13.0 - xtcocotools: 1.13 - yacs: 0.1.8 - yapf: 0.33.0 - yarl: 1.9.2 - zc.lockfile: 3.0.post1 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.11.5 - release: 5.15.0-1042-gcp - version: #50~20.04.1-Ubuntu SMP Mon Sep 11 03:30:57 UTC 2023

More info

The temporary fix is to add gc.collect() at the end of the teardown method while commenting out self.lightning_module.cpu().
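For clarity, a minimal sketch of that workaround as a standalone strategy subclass (the class name is mine; the body mirrors the teardown shown in the script above, with the .cpu() move commented out and a collection added):

import gc

import torch
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy


class PatchedSingleDeviceStrategy(SingleDeviceStrategy):
    """Hypothetical workaround: same teardown as the stock strategy, but the
    module is left on the device and garbage collection is forced at the end."""

    def teardown(self) -> None:
        _optimizers_to_device(self.optimizers, torch.device("cpu"))
        # self.lightning_module.cpu()  # skipped on purpose, see the note above
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()
        gc.collect()  # reclaim objects that are only reachable through reference cycles

It is passed to the Trainer exactly like the commented-out SingleDeviceProdStrategy in the script above.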

Things that I've tried:

cc @borda

ZekunZh commented 1 year ago

cc @awaelchli @four4fish @carmocca

awaelchli commented 1 year ago

Thanks for investigating this and providing example code, @ZekunZh!

I'm not sure we can remove self.lightning_module.cpu(); we would need to investigate the implications. It could be an unexpected breaking change for users. We should first check whether there is a different solution.

awaelchli commented 1 year ago

I ran a couple of tests with your script, removing Lightning and only running with the raw PyTorch model:

...
for i in range(N_ITERATIONS):
    torch.cuda.empty_cache()
    gc.collect()
    model = model.to("cuda:0")
    with torch.inference_mode():
        for batch in dataloader:
            model(batch.to("cuda:0"))
    # model.cpu()

    current_memory = process.memory_info().rss
    memory_usage = convert_bytes_to_megabytes(current_memory - initial_memory)
    print(f"Iteration {i + 1}: Resident Memory used: {memory_usage:.3f} MB")
    memory_usages.append(memory_usage)
...

[Figure: 2023-09-29T01h57m56s_memory_usage_raw-torch-no-move-to-cpu — resident memory usage with raw PyTorch, without moving the model back to CPU]

(results produced with torch nightly 2.2.0.dev20230920+cu121)

While the memory increase is definitely smaller, it is still a steady slope. I suppose that on a production system with thousands of requests these few MB could add up. I'm not deeply familiar with memory management in Python and PyTorch, but there seems to be some hidden state somewhere that is not just in Lightning. Perhaps Lightning only amplifies the impact and the root cause lies elsewhere.
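In case it helps localize the remaining growth, a small hypothetical helper (not part of Lightning) around the standard library's tracemalloc can attribute the Python-heap part of it; note that it cannot see CUDA driver or caching-allocator memory:

import tracemalloc


def top_heap_growth(fn, limit=10):
    """Run `fn` and print the source lines whose Python-heap usage grew the most.

    tracemalloc only tracks Python-level allocations, so it can explain at most
    part of the RSS growth (it will not see CUDA or C-level allocator memory).
    """
    tracemalloc.start(25)  # keep up to 25 frames per allocation
    before = tracemalloc.take_snapshot()
    fn()
    after = tracemalloc.take_snapshot()
    for stat in after.compare_to(before, "lineno")[:limit]:
        print(stat)
    tracemalloc.stop()


# e.g. wrap one iteration of the repro loop:
# top_heap_growth(lambda: trainer.predict(model, dataloader))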

MushroomMaula commented 7 months ago

Hey, I have the same problem, but during training. I am currently using Lightning in an active-learning loop in which I recreate the trainer on every iteration. Calling gc.collect() after each iteration fixes the issue.

I am using lightning 2.2.0.post0 and torch 2.2.0.

import gc
import os
from copy import deepcopy
from typing import Any

import psutil
import torch
import torch.nn.functional as F
from lightning import LightningModule, Trainer
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

class MLP(LightningModule):

    def __init__(
        self, 
        n_hidden=10240  # Quite a large value to amplify the effect, my actual model has roughly the same size
    ):  
        super().__init__()
        torch.manual_seed(0)
        self.model = nn.Sequential(
            nn.Linear(1024, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 2)
        )

        self._initial_state = deepcopy(self.state_dict())

    def forward(self, x) -> Any:
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch):
        x, y = batch
        pred = self.forward(x)
        loss = F.nll_loss(pred, y.view(-1))
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters())

    def reset(self):
        self.load_state_dict(self._initial_state)

def convert_bytes_to_megabytes(memory_bytes):
    return memory_bytes / 1024 ** 2

def run_pytorch(model, dataloader, iterations=10):
    model.train()

    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []

    for _ in range(iterations):
        model.reset()
        optimizer = optim.Adam(model.parameters())
        for x, y in tqdm(dataloader):
            x = x.to("cuda")
            y = y.to("cuda")
            pred = model(x)
            loss = F.nll_loss(pred, y)
            # loss = model.training_step((x, y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)

    return memory_usages

def run_lightning(model, dataloader, iterations=10, fix=False):
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []

    for _ in range(iterations):
        model.reset()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False)
        trainer.fit(model, train_dataloaders=dataloader)
        if fix:
            gc.collect()

        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)

    return memory_usages

def main():
    # Create some example data
    torch.manual_seed(0)
    X = torch.randn(size=(64, 1024))
    Y = (torch.rand(size=(64,)) < 0.2).long()
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        num_workers=1
    )

    model = MLP()
    # Move the model to the GPU
    model.to("cuda")

    torch_usage = run_pytorch(model, dataloader)
    lightning_usage = run_lightning(model, dataloader)
    lightning_usage_fix = run_lightning(model, dataloader, fix=True)

    plt.plot(torch_usage, label="PyTorch")
    plt.plot(lightning_usage, label="Lightning")
    plt.plot(lightning_usage_fix, label="Lightning - Fix")
    plt.legend()
    plt.xlabel("Iterations")
    plt.ylabel("Memory usage")
    plt.show()

if __name__ == "__main__":
    main()

[Figure: memory usage over iterations for PyTorch, Lightning, and Lightning with the gc.collect() fix]

PS: I think the memory for the fixed version is negative because the trainer moves some stuff directly to the GPU.

awaelchli commented 7 months ago

Thanks for collecting more data here. So if gc.collect() "fixes" this, it must mean there is nothing seriously wrong with the code in Lightning/PyTorch, because the references have been released. It's just Python not collecting the garbage fast enough? Is that right?
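For reference, my understanding of CPython's behavior (general, not specific to Lightning): objects whose refcount drops to zero are freed immediately; only objects kept alive by reference cycles wait for the generational cyclic collector, which scans the older generations infrequently. A quick, hypothetical check for whether the leftover objects fall into that category:

import gc

# The cyclic collector only runs when allocation counters cross these thresholds,
# and the older generations are scanned far less often than generation 0.
print(gc.get_threshold())  # default is (700, 10, 10)
print(gc.get_count())      # current per-generation counters

# gc.collect() returns the number of unreachable objects it found, i.e. objects
# that could only be freed by the cycle collector rather than by refcounting.
print(f"unreachable objects collected: {gc.collect()}")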

Hypothetically, if we were to insert a gc.collect() at the beginning of Trainer.fit() (cleaning up memory in case there was a trainer instance deleted before), would this be equivalent to your "fix"?

MushroomMaula commented 7 months ago

Yes, it seems that the result is the same. Using the following implementation with fix=True solves the issue.

class GCTrainer(Trainer):

    def fit(self, *args, fix: bool = False, **kwargs):
        if fix:
            gc.collect()
        super().fit(*args, **kwargs)

awaelchli commented 7 months ago

@carmocca What are your thoughts on adding a gc.collect() call at the beginning of Trainer's _run() function?

carmocca commented 7 months ago

I'm leaning towards not adding it. Instantiating trainers like this in a loop is very unconventional, and there is a cost to triggering gc for everybody else. We also don't understand why these objects are not getting freed periodically as you'd expect. Perhaps this is Python-version-dependent or platform-dependent.

If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this collect() call
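One way to gather that evidence, as a hypothetical diagnostic (the helper name is mine), is to report which object types survive until the cyclic collector runs after one Trainer round-trip:

import gc
from collections import Counter


def cyclic_garbage_report(fn, top=20):
    """Run `fn`, then count the types of objects that were only reclaimable by
    the cyclic garbage collector (i.e. kept alive by reference cycles)."""
    gc.collect()                    # start from a clean slate
    gc.set_debug(gc.DEBUG_SAVEALL)  # collected objects are parked in gc.garbage
    try:
        fn()
        gc.collect()
        return Counter(type(obj).__name__ for obj in gc.garbage).most_common(top)
    finally:
        gc.set_debug(0)
        gc.garbage.clear()


# e.g.: cyclic_garbage_report(lambda: Trainer(logger=False).predict(model, dataloader))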

awaelchli commented 7 months ago

If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this collect() call

@carmocca Just to clarify: above we've determined that the Trainer releases these objects, so it no longer holds references to them. It's just that the GC does not collect them from memory quickly enough. The fact that adding gc.collect() makes the memory drop means the objects were already unreachable, so there isn't any fix we could possibly do there. The GC is making the decisions here.

Instantiating trainers like this in a loop is very unconventional

I agree. In light of this I am also OK with closing this issue. But by the same argument, I am also OK with adding the gc.collect(). For the users who do loop over Trainers like this, there is already overhead in the setup and teardown of the trainer alone, so a gc.collect() shouldn't be noticeable IMO.
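For anyone who lands here with the per-request pattern from the original report, the user-side version of that workaround needs no changes to Lightning; a rough sketch (the handler name and Trainer arguments are illustrative, not prescriptive):

import gc

from lightning.pytorch import Trainer


def handle_request(model, dataloader):
    # One Trainer per request, as in the original report.
    trainer = Trainer(accelerator="gpu", devices=1, logger=False, enable_checkpointing=False)
    predictions = trainer.predict(model, dataloader)
    # Drop our reference and force a collection so trainer internals that are
    # only reachable through reference cycles are reclaimed right away.
    del trainer
    gc.collect()
    return predictions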