Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Lightning is very slow between epochs, compared to PyTorch. #10389

Closed TheMrZZ closed 1 year ago

TheMrZZ commented 2 years ago

I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.

However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I noticed that there was a ~10 second delay between each epoch. For comparison, with my vanilla PyTorch code, an epoch takes ~4s.

I first thought it was a data loading problem, but during the 10s delay, no data is loaded (at least that's what my print statements tell me).

I think the issue is related to the number of workers, because setting num_workers=0 solves the problem (but it is slower in the end, since a single worker is not enough). I know starting workers is slow; however, I have persistent_workers=True and this does not happen in plain PyTorch. My data loaders also have pin_memory=True (removing pin_memory does not solve the problem).
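For reference, here is a minimal sketch (with illustrative stand-ins, not the actual company code) of the kind of vanilla PyTorch loop I'm comparing against; with persistent_workers=True, the workers spawned on the first pass over the loader stay alive across epochs, so there is no per-epoch startup cost:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins for the real lazy dataset and the real model.
train_dataset = TensorDataset(torch.randn(2048, 16), torch.randn(2048, 1))
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

train_loader = DataLoader(
    train_dataset, batch_size=256, shuffle=True,
    num_workers=5, persistent_workers=True, pin_memory=True,
)

for epoch in range(5):
    # With persistent_workers=True, iterating the loader again reuses the
    # worker processes started in the first epoch instead of spawning new
    # ones, so plain PyTorch shows no multi-second pause between epochs.
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()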

Since this is company code, I cannot disclose the before/after, but I'll try to "anonymize" some code if necessary. Here is the lightning module:

class RawModule(pl.LightningModule):
    def __init__(self):
        super(RawModule, self).__init__()

        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)

        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)

        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)

        return result1, result2

    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)

        d = nn.functional.cosine_similarity(a, v)
        loss = logloss(d.unsqueeze(1), y)

        return loss

class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

if __name__ == '__main__':
    # stuff...

    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )

    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True

    if checkpoint_path is None:
        model = Module()

        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)

Here is the result of profiler="simple":

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  48.813               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  27.922               |1                      |  27.922               |  57.202               |
fetch_next_sanity_check_batch           |  4.4013               |3                      |  13.204               |  27.05                |
get_sanity_check_batch                  |  4.4013               |3                      |  13.204               |  27.05                |
fetch_next_train_batch                  |  1.2734               |10                     |  12.734               |  26.087               |
get_train_batch                         |  1.2734               |10                     |  12.734               |  26.087               |
run_training_batch                      |  0.47733              |9                      |  4.296                |  8.8009               |
optimizer_step_with_closure_0           |  0.40089              |9                      |  3.608                |  7.3915               |
validation_step                         |  0.664                |2                      |  1.328                |  2.7206               |
evaluation_step_and_end                 |  0.664                |2                      |  1.328                |  2.7206               |
training_step_and_backward              |  0.12644              |9                      |  1.138                |  2.3313               |
backward                                |  0.096889             |9                      |  0.872                |  1.7864               |
training_step                           |  0.029556             |9                      |  0.266                |  0.54494              |
model_forward                           |  0.029556             |9                      |  0.266                |  0.54494              |
on_train_start                          |  0.016                |1                      |  0.016                |  0.032778             |

Here is the result of profiler="advanced": https://pastebin.com/q3C5P826.

Finally, here is a video demonstrating the problem. I'm printing each piece of data loading, to prove it's not the issue. https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4

Some additional information:

cc @tchaton @rohitgr7 @borda @akihironitta

TheMrZZ commented 2 years ago

After a lot of digging around, I managed to pin down the line causing the problem.

It's line 142 in loops/epoch/training_epoch_loop.py:

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

        self._reload_dataloader_state_dict(data_fetcher)
-->     self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

Therefore, the culprit is:

def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter

The on_run_start hook is called from loops/base.py :

class Loop(ABC, Generic[T]):
    ...
    def run(self, *args: Any, **kwargs: Any) -> T:
        if self.skip:
            return self.on_skip()

        self.reset()

-->     self.on_run_start(*args, **kwargs)
        ...

And this run method is called from loops/fit_loop.py :

class FitLoop(Loop):
    ...
    def advance(self) -> None:
        """Runs one whole epoch."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
-->         self.epoch_loop.run(data_fetcher)
        ...

The problem is the data_fetcher. It's indeed related to dataloaders, as expected. I still don't know the root cause, but I'll try to find it.

a-kore commented 2 years ago

Just a guess, but maybe having a different number of workers between the training and validation dataloaders means it's spinning up new workers and getting rid of them between epochs?

Try making the number of workers equal for both train and val dataloaders (5-6 based on your CPU).

TheMrZZ commented 2 years ago

Just a guess, but maybe having a different number of workers between the training and validation dataloaders means it's spinning up new workers and getting rid of them between epochs?

Try making the number of workers equal for both train and val dataloaders (5-6 based on your CPU).

I just tried this, and it does not solve the problem. It would have been weird IMO, since it does not cause any problem with vanilla PyTorch.

a-kore commented 2 years ago

I see, that's interesting. Like you said, it does seem to be a data loading issue. Maybe try removing

        check_val_every_n_epoch=10,
        log_every_n_steps=5

from the trainer call and explicitly set "reload_dataloaders_every_epoch=False" and see what happens.

Other than that, I'd try a fresh install of pytorch-lightning in a new venv.

TheMrZZ commented 2 years ago

I see, that's interesting. Like you said, it does seem to be a data loading issue. Maybe try removing

        check_val_every_n_epoch=10,
        log_every_n_steps=5

from the trainer call and explicitly set "reload_dataloaders_every_epoch=False" and see what happens.

Other than that, I'd try a fresh install of pytorch-lightning in a new venv.

I just tried that, but it has no effect. The Trainer's methods _reset_eval_dataloader and reset_train_dataloader are never called (except once at the beginning), so it doesn't look like it's Lightning manually resetting the data loaders.

I also tried a fresh conda environment, but that didn't work either. I'm still trying to understand how the data loaders are getting reset, but I can't find anything really interesting yet.

marcm-ml commented 2 years ago

I had a similar observation where data_fetcher caused unusually long run times. For me it was indeed fixed by completely disabling multiprocess data loading (num_workers=0). Although, I have not tried to set "reload_dataloaders_every_epoch=False". Interesting to see that others have the same issue/observation. Funnily, setting num_workers=0 is what led me to open #10182. Perhaps there is something more to this?

TheMrZZ commented 2 years ago

TL;DR: I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. My code runs ~20x faster. It probably isn't the correct way to fix things, but my trainings work as well as they did before.

After a bunch of fiddling around, I decided to create a custom DataLoader and overload the __iter__ method. I discovered that the problem was that the _iterator property of the DataLoader was always set to None somewhere between epochs. When _iterator is None, the DataLoader is reset and needs to start everything from scratch.

# Original DataLoader
class DataLoader(Generic[T_co]):
    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

# Custom DataLoader
class CustomDataLoader(DataLoader):
    def __iter__(self) -> '_BaseDataLoaderIter':
        print(f'\n> DataLoader __iter__ with {self._iterator=} starting.\n')
        return super().__iter__()

As you can see in the normal DataLoader, having self._iterator set to None causes a call to self._get_iterator(), which reloads everything.

I decided to override _iterator with a custom property (getter & setter), to print the stack trace when ._iterator is set to None:

class CustomDataLoader(DataLoader):
    ...
    @property
    def _iterator(self):
        return self.__iterator

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            print('\nSetting __iterator to None. Stack trace:')
            import traceback
            traceback.print_stack()
        self.__iterator = value
        return self.__iterator

(I could also use the debugger for this)

This leads to 2 different yet very similar stack traces (respectively, evaluation & training loaders):

  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1370, in _run_sanity_check
    self._evaluation_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py", line 87, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 214, in reset
    CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None

  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 142, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 212, in reset
    self.dataloader.reset()
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 504, in reset
    apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
  File "env\lib\site-packages\pytorch_lightning\utilities\apply_func.py", line 92, in apply_to_collection
    return function(data, *args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None

Well... We're nearly there. It looks like advancing 1 epoch calls self.reset() on the DataFetcher itself, which then resets the DataLoader and leads to our problem.

Indeed, when checking AbstractDataFetcher, we have this:

class AbstractDataFetcher(...):
    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
        if self.dataloader is None:
            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
-->     self.reset()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching(self.prefetch_batches)
        return self

# And iter(AbstractDataFetcher) is called here, in utilities.py:
def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
-->     dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter

So... I guess we found the problem? Each time a new epoch runs, it calls self.advance on the FitLoop/EvalLoop, which then calls self.on_run_start. on_run_start creates the _dataloader_iter (which is normal behavior) by calling _update_dataloader_iter. This function, through enumerate, calls AbstractDataFetcher.__iter__, which calls self.reset(), entirely reloading the DataLoader.
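A tiny standalone illustration of that last step, in plain Python rather than Lightning code: enumerate() calls iter() on its argument as soon as it is constructed, which is why the reset fires right at the start of every epoch instead of lazily on the first batch.

class FetcherLike:
    def __iter__(self):
        # In Lightning, this is the point where AbstractDataFetcher.reset()
        # runs and the underlying DataLoader workers get shut down.
        print("__iter__ called")
        return iter(range(3))

batches = enumerate(FetcherLike(), 1)  # prints "__iter__ called" immediately
print(list(batches))                   # [(1, 0), (2, 1), (3, 2)]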

Notice that self.reset() is also called in the __init__ method of AbstractDataFetcher, for setup purposes.

I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. While it speeds up the code a lot, and the entire training seems to work, it probably breaks a bunch of things? I'd wait until the Lightning team fixes the problem before trying anything serious.
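For anyone who wants to run the same experiment without editing the installed package, roughly the same effect can be obtained with a monkey-patch based on the 1.5.0 __iter__ body quoted above (for experimentation only: it skips the reset and the misconfiguration check, may well break other Lightning features, and the module layout differs in later versions):

from pytorch_lightning.utilities.fetching import AbstractDataFetcher

def _iter_without_reset(self):
    # Same body as the 1.5.0 AbstractDataFetcher.__iter__ shown above,
    # minus the self.reset() call that tears down the dataloader workers
    # (and minus the "dataloader is None" sanity check, for brevity).
    self.dataloader_iter = iter(self.dataloader)
    self._apply_patch()
    self.prefetching(self.prefetch_batches)
    return self

AbstractDataFetcher.__iter__ = _iter_without_reset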

What a ride.

a-kore commented 2 years ago

Nice find. I'm on an older version (1.3.7) for a project I'm working on and I can't find a "fetching.py" under utilities. It must be something fairly new; maybe it was meant to be inside a conditional for "reload_dataloaders_every_epoch".

carmocca commented 2 years ago

This should have been fixed by #10434, which landed with the 1.5.1 release.

TheMrZZ commented 2 years ago

Tested the new 1.5.1 release today, looks like performance is back on track. Thanks to everyone!

tchaton commented 2 years ago

Dear @TheMrZZ,

Thanks for your investigation and happy we solved this ugly bug.

Best, T.C

isvogor-foi commented 2 years ago

Hi everyone... This topic is very interesting, as I'm hitting the same issue. I'm comparing the same implementation in Torch and Lightning. I came across this post and noticed I was using the ancient Lightning 1.4.7, so I updated to 1.5.9. I repeated my test and nothing changed... Torch was still significantly faster than Lightning. So, as @TheMrZZ suggested, I commented out reset in the __iter__ function and repeated the test. Sure enough, Lightning was... lightning fast now! I'm loading images, so I got the following:

Apparently the performance issue was fixed in 1.5.1; however, with 1.5.9 the reset line is still there. So I'm curious: why do we need to reset the data fetcher after each epoch?

carmocca commented 2 years ago

@isvogor-foi Are you saying you are again experiencing this problem with version 1.5.9 but not 1.5.1?

If so, can you try the versions in-between and report back your findings?

isvogor-foi commented 2 years ago

@carmocca Hi, well, I didn't try 1.5.1, I only tried 1.4.7 and 1.5.9. I'll see whether I can try 1.5.1 and get back to you on this!

isvogor-foi commented 2 years ago

@carmocca Hi... So I ran a test with 1.5.1. The performance issues are still there... A simple ResNet-18 + ImageNet run with 15000 images, for 15 epochs on a V100, 4 workers, batch size 256, prefetch factor 2. The plot below shows the images/s loading rate. As for the runtime:

[image: img/s loading and runtime comparison plots]

However, as already said by @TheMrZZ, removing the self.reset in __iter__ of fetching.py changes everything. Lightning's performance increases several-fold and outperforms Torch.

So the question remains: why is self.reset necessary if it hurts performance so much?

carmocca commented 2 years ago

@isvogor-foi Happy to look at the issue if you share that vanilla lightning example

isvogor-foi commented 2 years ago

@carmocca Sure, "vanilla" meaning the one taken from the official implementations:

Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py
Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py

Try running those on the same device, with the same image dataset, batch size, etc. I've tried on my machine, on an AWS instance, even on Colab, and Torch is always much better.

amin-nejad commented 2 years ago

Sounds like this issue should be reopened

isvogor-foi commented 2 years ago

@amin-nejad I agree. @TheMrZZ shall we reopen? @carmocca did you test?

carmocca commented 2 years ago

Hi @isvogor-foi!

I had a look but I'm not observing any speed differences after commenting reset.

Can you describe exactly what changes you are making? Is the same behaviour reproducible in master? You can also reach me in our Slack if you find that easier.

isvogor-foi commented 2 years ago

@carmocca This is very curious. There should be some difference. The effect should show up if you use many epochs, e.g. 50: it should be faster, since the DataLoader will not be reset and recreated after each epoch, and this recreation is usually expensive. Do you know whether you're using the "fork" or "spawn" start method?

Aha, it's also important not to use LightningDataModule.
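If you're unsure, a quick way to check is to print the multiprocessing start method; this is the standard PyTorch/Python API, shown here only as a hint:

import torch.multiprocessing as mp

# "fork" (the Linux default) makes worker startup cheap, while "spawn"
# (the Windows/macOS default) re-imports the whole script for every worker,
# which makes any dataloader re-creation between epochs far more expensive.
print(mp.get_start_method())

# The start method can also be chosen per loader via the standard
# multiprocessing_context argument of torch.utils.data.DataLoader.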

isvogor-foi commented 2 years ago

I downloaded the latest version, 1.6.0dev, and saw there are some changes. I also retried with MNIST, and (@TheMrZZ) it seems that commenting out reset() is dangerous: it only runs the first epoch and then kills the following epochs. That said, in that case Torch still runs much faster for me. I use the same custom dataloader for both, on these examples:
Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py
Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py
If I find anything else, I'll report back.

carmocca commented 2 years ago

Here are some simple time results:

This was using 1 NVIDIA GeForce RTX 3090 and PyTorch Lightning 1.6.0dev, commit 8394770d4afa5480f881229b150ac44eaa8c41b0, torch==1.10.1, torchmetrics==0.7.0, torchvision==0.11.2, Python 3.8.12.

PyTorch Lightning

real    1m28.276s
user    5m29.551s
sys     0m34.544s

I used https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py verbatim, called with

time python imagenet.py fit --model.data_path /home/imagenet/data --trainer.limit_train_batches=100 --trainer.limit_val_batches=0 --trainer.max_epochs=2

PyTorch

real    1m29.662s
user    5m15.496s
sys     0m31.530s

I used https://github.com/pytorch/examples/blob/master/imagenet/main.py with the following changes applied to disable validation, and stop at 100 batches:

249c249
<         acc1 = validate(val_loader, model, criterion, args)
---
>         #acc1 = validate(val_loader, model, criterion, args)
252,253c252,253
<         is_best = acc1 > best_acc1
<         best_acc1 = max(acc1, best_acc1)
---
>         is_best = False
>         best_acc1 = best_acc1
281a282,283
>         if i == 100:
>             break

time python torch_imagenet.py /home/imagenet/data --epochs=2 --gpu=0

So basically the same speed.

isvogor-foi commented 2 years ago

Hm... very interesting. I did the same with the MNIST example that comes with the Lightning examples, and noticed that Torch was only slightly better in my case, a second or so. However, I've just dissected the advance call in training_epoch_loop.py. Rather than explaining all the details, I've added some arrows to indicate the execution timeline.

[image: annotated screenshot of the advance call execution timeline]

So the training loop is this:

with self.trainer.profiler.profile("run_training_batch"):
    batch_output = self.batch_loop.run(batch, batch_idx)

and before it, this part takes a lot of time:

  response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
  if response == -1:
      self.batch_progress.increment_processed()
      raise StopIteration

and after it there is another long call:

  self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
  self.trainer.call_hook("on_batch_end")
  self.trainer.logger_connector.on_batch_end()

Therefore excessive logging can slow it down.
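For reference, if per-step logging turns out to be the main cost in those hooks, one standard mitigation is to aggregate logs per epoch instead of per batch. A minimal, self-contained sketch with a hypothetical module (not the code from this thread):

import torch
import pytorch_lightning as pl

class LessChattyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # on_step=False / on_epoch=True logs one aggregated value per epoch
        # instead of going through the logger connector on every batch.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)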

So, in my particular case (Lightning 1.5.9 with one V100), the aforementioned hook calls seem to build up over time, and in my long experiment with 100 epochs, batch size 256, and 35k images, Torch performs better. That was with just 1 GPU; I didn't try multiple GPUs. I think we can leave it at that.

@carmocca thanks!

carapas commented 2 years ago

Hi @isvogor-foi, I had the same issue as reported here, and after updating pytorch-lightning I didn't see any improvement either. After reading the OP's blog post: https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1, I noticed that I was missing the persistent_workers=True flag on my DataLoader:

# My data Loader parameters
DataLoader(
  train_dataset, batch_size=64, shuffle=True, num_workers=n_workers,
  persistent_workers=True, pin_memory=True,
)

Hopefully this will help you! Performance was much improved for me.

JonathanSum commented 2 years ago

Just asking, how many GPUs did you use? Were you using Colab? Thanks for answering.

isvogor-foi commented 2 years ago

Just asking, how many GPUs did you use? Were you using Colab? Thanks for answering.

Nothing fancy, just a single V100. As for Colab, yep, I tried it too. The results were similar, but I couldn't perform an extensive test, as Colab limits usage.

mayiran1999 commented 2 years ago

I ran into the same problem on PL 1.5.10... I have to refactor my code to use plain PyTorch...

carmocca commented 2 years ago

@mayiran1999 Have you tried the 1.7.0 release?

KameniAlexNea commented 2 years ago

I have a similar issue with version 1.7.1 and torch 1.12.1+cu102 when using more than one worker (3).

My trainer also takes a few extra seconds before ending.

gikok commented 2 years ago

This issue appears to be back in the latest version of PyTorch Lightning (1.7.7). I was training a model on a rather small dataset, so each epoch would take only ~1 min, but the trainer would pause for ~1m40s between epochs. I'm using torch 1.9.0+cu111.

It went away after downgrading to version 1.5.1.

carmocca commented 2 years ago

Can you provide a reproducible snippet? Alternatively, you could try git bisecting to find out which commit causes this

lminer commented 2 years ago

I'm also having this issue with num_workers=0, DDP with 4 GPUs, lightning 1.7.7 and torch 1.12.0. Interestingly, it only happens if I try to cache my entire dataset during the setup method of the datamodule, which takes up about 150 GB out of 250. If I load data on the fly (which is unfortunately much slower), there is no pause between epochs. The problem is not present in 1.6.5, but it is there in 1.7.0 and above.

lminer commented 2 years ago

@carmocca I ran git bisect and this is what I get:

cd01856ffcbbf4d7b6dc494fef3d63cfc2eb6563 is the first bad commit
commit cd01856ffcbbf4d7b6dc494fef3d63cfc2eb6563
Author: Rohit Gupta <rohitgr1998@gmail.com>
Date:   Tue May 3 17:57:06 2022 +0530

    Add `LightningDataModule.load_from_checkpoint` to load datamodules directly from checkpoint (#12550)

    Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
    Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
    Co-authored-by: otaj <ota@grid.ai>

 CHANGELOG.md                                       |   2 +
 docs/source/common/checkpointing_basic.rst         |   1 +
 pytorch_lightning/core/datamodule.py               |  76 +++++++-
 pytorch_lightning/core/saving.py                   | 214 ++++++++++++---------
 .../trainer/connectors/checkpoint_connector.py     |  20 +-
 tests/models/test_hparams.py                       |  99 +++++++---
 6 files changed, 279 insertions(+), 133 deletions(-)

I'm very confused as to why this commit is causing the issue. I thought maybe it's an issue with checkpointing, so I removed the checkpointing callback, but still get the pause.

carmocca commented 2 years ago

Thank you. That is very helpful. cc @rohitgr7

Are you using self.save_hyperparameters()?

lminer commented 2 years ago

Yes, I am using save_hyperparameters(). Happy to help!


rohitgr7 commented 1 year ago

@lminer this is not good. Can you share a reproducible script using BoringModel to check this issue?

lminer commented 1 year ago

Unfortunately I don't have the time to make a script right now. My guess, though, is that you could duplicate this behavior by creating a very large in-memory dataset (mine is around 40 GB per GPU) and then running save_hyperparameters() in the datamodule.


jzazo commented 1 year ago

I wouldn't know how to make this reproducible either, but I have pinned PL to 1.6.5 in my project because I observe the same behavior as @lminer on 1.7.* with very slow epochs. I also have very large memory usage, although I am using a single GPU and I am not calling save_hyperparameters at all. Could there be another underlying factor?

rohitgr7 commented 1 year ago

Okay... then there might be something wrong here. For starters, I'd suggest using the simple profiler (profiler='simple') to test things out. Try turning off checkpointing completely (enable_checkpointing=False).

Make sure to use limit_train_batches/limit_val_batches so that you don't have to wait for the whole dataset to complete. You can also just return from training_step/validation_step without doing any model computation :)
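For example, something along these lines (a minimal sketch; the batch limits are arbitrary):

import pytorch_lightning as pl

# Narrow down where the time goes: profile, cap the number of batches, and
# disable checkpointing so epoch boundaries only exercise the loops themselves.
trainer = pl.Trainer(
    profiler="simple",
    limit_train_batches=20,
    limit_val_batches=5,
    enable_checkpointing=False,
    max_epochs=3,
)
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)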

lminer commented 1 year ago

@rohitgr7 I think I've found the issue. Previously, I was passing the datasets into the LightningDataModule, running save_hyperparameters() and then accessing them via self.hparams["train_dataset"], self.hparams["validation_dataset"], etc. It seems the entire cached dataset was being saved along with the hyperparameters in this case.

The fix is to explicitly make these datasets instance variables and then run self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"]).
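A minimal sketch of that pattern (illustrative names, not the actual project code):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class CachedDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, validation_dataset, test_dataset):
        super().__init__()
        # Keep the (potentially huge) datasets as plain instance attributes...
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.test_dataset = test_dataset
        # ...and exclude them from the saved hyperparameters so they are not
        # serialized along with every checkpoint.
        self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)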

rohitgr7 commented 1 year ago

Great! Datasets are not hparams, so you should be careful about what you save there.

@jzazo is it a similar issue in your case as well? Can you check your datamodule hparams?

jzazo commented 1 year ago

I am not using datamodules; I am passing dataloaders directly to the trainer. I am also not using the save_hyperparameters functionality. But I will check whatever I am attaching to the module in case there is any serialization happening. I will profile the training as well. I think I will have some time next week to verify this, and I will report back.

awaelchli commented 1 year ago

@lminer You are correct, this is the right way to handle this when using save_hyperparameters().

@jzazo My recommendation is to check whether checkpointing is taking a long time (this happens between epochs). You could check by simply setting enable_checkpointing=False in the Trainer. Please feel free to report back in a new issue with your findings, or ping us on Slack if you need more guidance on debugging this.

is-jlehrer commented 1 year ago

This still seems to be the case, especially with dataloader startup time

Jason94 commented 1 year ago

Is this still fixed on 2.0.4? I'm still seeing this behavior. This code runs lightning fast (albeit with warnings about not having any workers):

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        # num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        # num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

Adding the workers makes the warnings go away, but the training freezes for ~15 seconds before validation or a new epoch:

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

aktgpt commented 9 months ago

Hi,

This issue still persists in 2.1.3. I'm directly passing the dataloaders to the Trainer with train_dataloader=DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn_train, drop_last=True, shuffle=True, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True), but the training freezes for a few seconds after each epoch. When I ran the advanced profiler on run_training_epoch, I saw that reset is called every epoch (dataloader.py:1086(_reset) and training_epoch_loop.py:143(reset)). The training is extremely slow if I set num_workers=0.

Is there a fix/workaround for this?

aktgpt commented 9 months ago

Is there any update on this? It seems like I'm not the only one facing this issue. I started using PTL at 2.x and would prefer not to downgrade to 1.5.1 to make this issue go away. Is there a 2.x version that doesn't have this issue? @awaelchli @Borda @carmocca

leventt commented 7 months ago

This is indeed still an issue with the latest version... I have checked a few things, such as persistent workers, and made a small comparison against vanilla PyTorch with a nested for loop, and I can confirm that "lightning" may be an inaccurate way to describe this library right now.

otakudj commented 6 months ago

I think this issue should be reopened. I hit this problem with PL version 2.0.3.