Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Pruning callback causes GPU memory leak when used iteratively #8542

Open MohammedAljahdali opened 3 years ago

MohammedAljahdali commented 3 years ago

Discussed in https://github.com/PyTorchLightning/pytorch-lightning/discussions/8363

Originally posted by **MohammedAljahdali** July 10, 2021

Hi, I have a script that does the following logic:

```
model = Model()
dm = DataModule()
callback_a = ModelPruning()
for _ in range(N):
    callbacks = [CallbackB(), CallbackC(), callback_a]
    trainer = Trainer(callbacks=callbacks, ...)
    trainer.fit(model, dm)
    trainer.test()
```

There is much more going on, but to keep it simple, this is the flow that I have. My question is: is there a preferable way to do what I just did? Also, my code has a memory leak that occurs after each loop iteration. Could this be related to the trainer object not being deleted properly?

After trying to reproduce the issue with the boring model, it turned out that the cause of the memory leak is not the reinitialization of the trainer, but the pruning callback itself.

Dependencies:

  - python=3.8
  - pip
  - cudatoolkit=10.2
  - pytorch=1.8.1
  - torchvision=0.9.1
  - pytorch-lightning>=1.3.2
  - torchmetrics>=0.3.2
  - wandb>=0.10.30

Code to reproduce:

import os
import gc

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import wandb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
import torchvision
import pytorch_lightning as pl
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pytorch_lightning.utilities.distributed import rank_zero_only
from collections import OrderedDict, defaultdict

class DM(LightningDataModule):

    def __init__(self, n_features, n_samples, batch_size):
        super().__init__()
        self.train_dataset = torchvision.datasets.CIFAR10(root='/tmp', download=True,
                                                          transform=torchvision.transforms.ToTensor())
        self.val_dataset = torchvision.datasets.CIFAR10(root='/tmp', transform=torchvision.transforms.ToTensor())
        self.test_dataset = torchvision.datasets.CIFAR10(root='/tmp', transform=torchvision.transforms.ToTensor())
        self.batch_size = batch_size

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False
        )

class BoringModel(LightningModule):

    def __init__(self, in_features):
        super().__init__()
        self.run_id = 0
        self.test_counter = 0
        self.net = torchvision.models.resnet18()
        self.net.fc.is_classifier = True
        self.loss = torch.nn.CrossEntropyLoss()

    #         self.layer = torch.nn.Sequential(
    #             torch.nn.Linear(in_features, in_features // 2),
    #             torch.nn.ReLU(),
    #             torch.nn.Linear(in_features // 2, in_features // 4),
    #             torch.nn.ReLU(),
    #             torch.nn.Linear(in_features // 4, in_features // 8),
    #             torch.nn.ReLU(),
    #             torch.nn.Linear(in_features // 8, 2),
    #         )

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"train_loss {self.run_id}", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"valid_loss {self.run_id}", loss)

    def test_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"test_loss {self.run_id}", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.net.parameters(), lr=1)

def run():
    in_features = 224
    N = 20
    dm = DM(n_samples=40, n_features=in_features, batch_size=256)
    model = BoringModel(in_features=in_features)
    loggers = [WandbLogger(project='test_runs', save_dir='/tmp')]
    pruning_callback = pl.callbacks.ModelPruning(
        apply_pruning=True, use_lottery_ticket_hypothesis=True,
        pruning_fn='l1_unstructured', use_global_unstructured=True, verbose=1, make_pruning_permanent=False,
        amount=0.5
    )

    for i in range(N):
        callbacks = [
            pl.callbacks.ModelCheckpoint(monitor=f"valid_loss {model.run_id}", save_top_k=1, save_last=True),
            pl.callbacks.EarlyStopping(monitor=f"valid_loss {model.run_id}", )
        ]
        callbacks.append(pruning_callback)
        trainer = Trainer(
            default_root_dir='/tmp',
            limit_train_batches=4,
            limit_val_batches=4,
            limit_test_batches=4,
            num_sanity_val_steps=0,
            max_epochs=5,
            logger=loggers,
            gpus=1,
            callbacks=callbacks
        )

        trainer.fit(model, datamodule=dm)
        trainer.test(model, ckpt_path='best')
        trainer.test(model)
        model.run_id += 1
        counter = 0
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    counter += 1
            except Exception:
                pass
        print(f"Number of tensors after level {i} is {counter}")

if __name__ == '__main__':
    run()

tchaton commented 3 years ago

Dear @MohammedAljahdali,

Would you mind trying out master? We should have resolved this problem. Could you confirm?

Best, T.C

MohammedAljahdali commented 3 years ago

Dear @tchaton,

I tried pl 1.5.0.dev0 and 1.4.1. On my code base the issue still happens, although, oddly, the number of tensors after each pruning iteration does not always increase:

Number of tensors after level 0 is 6166
Number of tensors after level 0 is 6166

Number of tensors after level 1 is 7249
Number of tensors after level 1 is 7249

Number of tensors after level 2 is 12161
Number of tensors after level 2 is 12160

Number of tensors after level 3 is 17079
Number of tensors after level 3 is 17079

Number of tensors after level 4 is 10493
Number of tensors after level 4 is 10493

Number of tensors after level 5 is 11576
Number of tensors after level 5 is 11575

Number of tensors after level 6 is 16489
Number of tensors after level 6 is 16488

Number of tensors after level 7 is 13739
Number of tensors after level 7 is 13739

Number of tensors after level 8 is 17014
Number of tensors after level 8 is 17013

Number of tensors after level 9 is 15909
Number of tensors after level 9 is 15907

I call torch.cuda.empty_cache() between the two prints for each level, and sometimes it removes a tensor or two.
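For readers reproducing these numbers, here is a minimal sketch of the counting step as a reusable helper. The names `count_live_tensors` and `memory_snapshot` are hypothetical and not part of the original script. Unlike the loop in the repro, this version counts only objects that are tensors themselves, not objects whose `.data` attribute is a tensor. It also reports the CUDA allocator statistics when a GPU is present. `torch.cuda.empty_cache()` only returns cached blocks that no tensor is using, so it is not expected to free tensors that are still referenced.

```python
import gc

import torch


def count_live_tensors() -> int:
    """Count torch tensors currently tracked by Python's garbage collector."""
    gc.collect()  # drop unreachable reference cycles first so they are not counted
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
        except Exception:
            # Some objects raise on isinstance checks; skip them.
            continue
    return count


def memory_snapshot() -> dict:
    """Return the live tensor count plus CUDA allocator statistics (zeros on CPU-only machines)."""
    snapshot = {"tensors": count_live_tensors(), "allocated_bytes": 0, "reserved_bytes": 0}
    if torch.cuda.is_available():
        # allocated: memory occupied by live tensors; reserved: allocated plus cached blocks
        snapshot["allocated_bytes"] = torch.cuda.memory_allocated()
        snapshot["reserved_bytes"] = torch.cuda.memory_reserved()
    return snapshot


# Small demonstration: holding references keeps the tensors alive.
before = memory_snapshot()
held = [torch.zeros(8) for _ in range(3)]
after = memory_snapshot()
print(f"tensors before: {before['tensors']}, after: {after['tensors']}")
```

Comparing `allocated_bytes` across iterations is usually a more direct signal of a GPU leak than the raw tensor count, since the count also includes CPU tensors.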

When I tested pl 1.4.1 and 1.5.0 with the boring model, I got the following error:

Traceback (most recent call last):
  File "/home/aljahdmk/projects/shrinkbench/pruning_leak.py", line 123, in <module>
    run()
  File "/home/aljahdmk/projects/shrinkbench/pruning_leak.py", line 109, in run
    trainer.fit(model, datamodule=dm)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 862, in _run
    self.call_hook("on_before_accelerator_backend_setup", model)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 42, in on_before_accelerator_backend_setup
    callback.on_before_accelerator_backend_setup(self, model)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py", line 385, in on_before_accelerator_backend_setup
    self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/home/aljahdmk/miniconda3/envs/nnp-nas-2/lib/python3.8/site-packages/torch/tensor.py", line 55, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

This is the boring model code:

import gc

import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
import torchvision
import pytorch_lightning as pl

class DM(LightningDataModule):

    def __init__(self, n_features, n_samples, batch_size):
        super().__init__()
        self.train_dataset = torchvision.datasets.CIFAR10(root='/tmp', download=True,
                                                          transform=torchvision.transforms.ToTensor())
        self.val_dataset = torchvision.datasets.CIFAR10(root='/tmp', transform=torchvision.transforms.ToTensor())
        self.test_dataset = torchvision.datasets.CIFAR10(root='/tmp', transform=torchvision.transforms.ToTensor())
        self.batch_size = batch_size

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=1,
            shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=1,
            shuffle=False
        )

class BoringModel(LightningModule):

    def __init__(self, in_features):
        super().__init__()
        self.run_id = 0
        self.test_counter = 0
        self.net = torchvision.models.resnet18()
        self.net.fc.is_classifier = True
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"train_loss {self.run_id}", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"valid_loss {self.run_id}", loss)

    def test_step(self, batch, batch_idx):
        loss = self.loss(self(batch[0]), batch[1])
        self.log(f"test_loss {self.run_id}", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.net.parameters(), lr=1)

def run():
    in_features = 224
    N = 20
    dm = DM(n_samples=40, n_features=in_features, batch_size=256)
    model = BoringModel(in_features=in_features)
    loggers = [WandbLogger(project='test_runs', save_dir='/tmp')]
    pruning_callback = pl.callbacks.ModelPruning(
        apply_pruning=True, use_lottery_ticket_hypothesis=True,
        pruning_fn='l1_unstructured', use_global_unstructured=True, verbose=1, make_pruning_permanent=False,
        amount=0.2,
    )

    for i in range(N):
        callbacks = [pl.callbacks.ModelCheckpoint(monitor=f"valid_loss {model.run_id}", save_top_k=1, save_last=True),
                     pl.callbacks.EarlyStopping(monitor=f"valid_loss {model.run_id}", ), pruning_callback]
        trainer = Trainer(
            default_root_dir='/tmp',
            limit_train_batches=4,
            limit_val_batches=4,
            limit_test_batches=4,
            num_sanity_val_steps=0,
            max_epochs=5,
            logger=loggers,
            gpus=1,
            callbacks=callbacks,
            log_every_n_steps=2
        )

        trainer.fit(model, datamodule=dm)
        # trainer.test(ckpt_path='best')
        # trainer.test(model, ckpt_path=None)
        model.run_id += 1
        counter = 0
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    counter += 1
            except Exception:
                pass
        print(f"Number of tensors after level {i} is {counter}")

if __name__ == '__main__':
    run()

Kindest regards,

MohammedAljahdali commented 3 years ago

Hi @tchaton,

From what I gather from the recent changes to the pruning callback, is the issue supposed to be solved by the change that prevents deep copying of dataloaders? If so, where in the pruning callback is it applied? I only found it used in the stochastic weight averaging callback.

If you could point me to where this issue is caused, or tell me how I could help solve it, I would gladly help.

tchaton commented 3 years ago

Dear @MohammedAljahdali,

Feel free to make a PR with a fix for this behaviour. I will assign you to this ticket.

Best, T.C

tchaton commented 3 years ago

Hey @MohammedAljahdali,

Any updates on this? I believe we could add a test that checks for leaked tensors.

I believe it happens within those lines: https://github.com/PyTorchLightning/pytorch-lightning/blob/b3e9dff32d842431b067b0ab83e508ffe3262968/pytorch_lightning/callbacks/pruning.py#L377.

Somehow, the deepcopy calls create new tensor references that are not properly cleaned up after each run.

Are you still interested in investigating this?
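As a rough illustration of what such a test could look like, the sketch below runs a step function several times and measures how much the number of live tensors grows. Everything here is hypothetical: `tensor_growth`, `leaky_step`, and `clean_step` are illustrative names, and the `retained_copies` list is only a stand-in for state that a callback might keep between runs. It does not reproduce the actual behaviour of the pruning callback.

```python
import copy
import gc

import torch
from torch import nn


def count_live_tensors() -> int:
    """Count torch tensors currently tracked by Python's garbage collector."""
    gc.collect()
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
        except Exception:
            continue
    return count


def tensor_growth(step, rounds: int = 5) -> int:
    """Run step() repeatedly and return the change in the live tensor count."""
    step()  # warm-up round so one-time allocations are not counted as growth
    baseline = count_live_tensors()
    for _ in range(rounds):
        step()
    return count_live_tensors() - baseline


model = nn.Linear(4, 2)
retained_copies = []  # hypothetical stand-in for state kept alive across runs


def leaky_step() -> None:
    # Each call keeps a new deep copy alive, so tensors accumulate.
    retained_copies.append(copy.deepcopy(model))


def clean_step() -> None:
    # The copy goes out of scope immediately and can be freed.
    copy.deepcopy(model)


print("growth with retained copies:", tensor_growth(leaky_step))
print("growth without retained copies:", tensor_growth(clean_step))
```

A real regression test would replace the two step functions with a short `trainer.fit` call and assert that the growth stays below a small threshold.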

Best, T.C

MohammedAljahdali commented 3 years ago

Hi @tchaton,

I am still interested in this problem, but sadly I am very busy with school at the moment, so I cannot investigate it.

tchaton commented 3 years ago

Dear @MohammedAljahdali,

Feel free to investigate whenever you have some bandwidth.

Best, T.C

JinLi711 commented 2 years ago

I investigated the issue and it seems that the error

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

happens because we're trying to deep copy parameters that are not explicitly defined by the user but are rather existing deep copies.

So going through the pruning code, we call on_before_accelerator_backend_setup, which deep copies self._parameters_to_prune to self._original_layers. When self.on_train_epoch_end is called, that function calls self.apply_lottery_ticket_hypothesis(), which reassigns the data in self._parameters_to_prune with weights from self._original_layers, which are deep copies.

So when on_before_accelerator_backend_setup is called again, the error above is raised because we are trying to deep copy data that are already deep copies and not explicitly defined by the user.
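The condition named in the error message can be reproduced outside Lightning with `torch.nn.utils.prune` alone. The sketch below is an isolated illustration and rests on an assumption: it shows that a module whose pruning has not been made permanent holds a non-leaf `weight` tensor and therefore cannot be deep copied. It does not confirm that this is the exact path taken inside the callback. The helper `try_deepcopy` is a hypothetical name.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import prune


def try_deepcopy(module: nn.Module):
    """Return None if deepcopy succeeds, otherwise the error message."""
    try:
        copy.deepcopy(module)
        return None
    except RuntimeError as err:
        return str(err)


layer = nn.Linear(4, 2)
weight_was_leaf_before = layer.weight.is_leaf  # a freshly created parameter is a graph leaf

# Pruning replaces `weight` with a tensor computed from `weight_orig` and
# `weight_mask`, so it is the result of an operation rather than a leaf.
prune.l1_unstructured(layer, name="weight", amount=0.5)
weight_is_leaf_while_pruned = layer.weight.is_leaf
error_while_pruned = try_deepcopy(layer)

# Making the pruning permanent restores `weight` as an ordinary parameter.
prune.remove(layer, "weight")
error_after_remove = try_deepcopy(layer)

print("leaf before pruning:", weight_was_leaf_before)
print("leaf while pruned:", weight_is_leaf_while_pruned)
print("deepcopy error while pruned:", error_while_pruned)
print("deepcopy error after prune.remove:", error_after_remove)
```

In the reproduction script the callback is created with `make_pruning_permanent=False`, which is consistent with the model still carrying the reparametrized weights when the next `trainer.fit` call begins.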

I don't think a pull request is needed to solve this issue. I think if you reinitialize the BoringModel within the for loop and save and load the weights into that, your code should work fine.

So something like:

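weights = None
for i in range(N):
    # Re-create the model each round so the callback deep-copies user-created parameters
    model = BoringModel(in_features=in_features)
    load_weights(model, weights)  # placeholder helper: restore weights from the previous round
    pruning_callback = ...

    trainer = ...
    trainer.fit(model, datamodule=dm)
    weights = get_weights(model)  # placeholder helper: capture weights for the next round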