Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Lightning is very slow between epochs, compared to PyTorch. #10389

Closed TheMrZZ closed 1 year ago

TheMrZZ commented 2 years ago

I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.

However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I noticed that there was a ~10-second delay between each epoch. For comparison, in my vanilla PyTorch code, an epoch takes ~4s.

I first thought it was a data loading problem, but during the 10s delay, no data is loaded (at least that's what my print statements tell me).

I think the issue is related to the number of workers, because setting num_workers=0 solves the problem (but is slower in the end, since a single worker is not enough). I know starting workers is slow, but I have persistent_workers=True, and this does not happen in plain PyTorch. My data loaders also have pin_memory=True (removing pin_memory does not solve the problem).
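
For reference, here is a minimal sketch (dummy data, not the company code) of the behaviour I expect from plain PyTorch: with persistent_workers=True the workers should be forked once, on the first epoch, and then reused.

# Minimal sketch with dummy data (not the company code): persistent workers
# are started once, when the first iterator is created, and reused afterwards.
# (On Windows/macOS, wrap this in an `if __name__ == "__main__":` guard.)
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

dummy_dataset = TensorDataset(torch.randn(2048, 8), torch.randn(2048, 1))
loader = DataLoader(dummy_dataset, batch_size=256, shuffle=True,
                    num_workers=5, persistent_workers=True, pin_memory=True)

for epoch in range(3):
    start = time.perf_counter()
    for batch in loader:      # only epoch 0 should pay the worker start-up cost
        pass
    print(f"epoch {epoch}: {time.perf_counter() - start:.2f}s")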

Since this is company code, I cannot disclose the before/after, but I'll try to "anonymize" some code if necessary. Here is the lightning module:

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import pytorch_lightning as pl


class RawModule(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)

        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)

        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)

        return result1, result2

    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)

        d = F.cosine_similarity(a, v)
        loss = logloss(d.unsqueeze(1), y)  # logloss is defined elsewhere in the original code

        return loss

class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

if __name__ == '__main__':
    # stuff...

    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )

    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True

    if checkpoint_path is None:
        model = Module()

        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)

Here is the result of profiler="simple":

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  48.813               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  27.922               |1                      |  27.922               |  57.202               |
fetch_next_sanity_check_batch           |  4.4013               |3                      |  13.204               |  27.05                |
get_sanity_check_batch                  |  4.4013               |3                      |  13.204               |  27.05                |
fetch_next_train_batch                  |  1.2734               |10                     |  12.734               |  26.087               |
get_train_batch                         |  1.2734               |10                     |  12.734               |  26.087               |
run_training_batch                      |  0.47733              |9                      |  4.296                |  8.8009               |
optimizer_step_with_closure_0           |  0.40089              |9                      |  3.608                |  7.3915               |
validation_step                         |  0.664                |2                      |  1.328                |  2.7206               |
evaluation_step_and_end                 |  0.664                |2                      |  1.328                |  2.7206               |
training_step_and_backward              |  0.12644              |9                      |  1.138                |  2.3313               |
backward                                |  0.096889             |9                      |  0.872                |  1.7864               |
training_step                           |  0.029556             |9                      |  0.266                |  0.54494              |
model_forward                           |  0.029556             |9                      |  0.266                |  0.54494              |
on_train_start                          |  0.016                |1                      |  0.016                |  0.032778             |

Here is the result of profiler="advanced": https://pastebin.com/q3C5P826.

Finally, here is a video demonstrating the problem. I'm printing each piece of data loading, to prove it's not the issue. https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4

Additional information:

cc @tchaton @rohitgr7 @borda @akihironitta

jordan7186 commented 5 months ago

Perhaps this issue persists? I'm still experiencing a similar freeze before validation on version 2.2.4.

Lunamos commented 2 months ago

This issue still exists in version 2.3.3. With higher num_workers, the time between epochs is significantly longer. I tested the influence of checkpoint saving and hyperparameter logging and found that these settings do not affect the runtime. The bug matches the initial finding and is caused by the dataloader.

An easy fix is setting num_workers=0 or adding persistent_workers=True when instantiating the dataloader.
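
As a minimal sketch of the workaround (the dataset here is just a dummy placeholder, not my actual data):

import torch
from torch.utils.data import DataLoader, TensorDataset

train_dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))  # placeholder dataset

# Keep the worker processes alive between epochs instead of respawning them
# at every epoch boundary; alternatively, set num_workers=0 to avoid workers entirely.
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
)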

It's unfortunate that such an ugly bug, which had been fixed before, still seriously affects the training speed of Lightning 2.x.

I think this issue needs to be reopened and fixed as soon as possible. @awaelchli

leventt commented 2 months ago

persistent_workers=True

did not actually work around the issue when I was testing

KalinNonchev commented 2 months ago

Hi, version v2.4 still has this issue, and the recommendations mentioned are not working.

meakbiyik commented 1 month ago

Folks, I might have a solution.

TL;DR: use OMP_NUM_THREADS=1 MKL_NUM_THREADS=1

The problem stems from a combination of weird torch defaults, using Slurm or some comparable scheduler without containerization, and a large num_workers count.

By default, torch uses as many threads as possible for interop and intraop operations. This "as many" is determined by the number of CPU cores in your system (see here).

If you are using a scheduler such as Slurm, torch will think that you have access to all the CPUs in your machine (since the node resources are visible to the job) even if you have limited the number of cores allocated to the job. Therefore, e.g. in a 100-core node, torch will spawn hundreds of threads for each worker, suffocating your system.

The solution is to reduce the number of threads that torch can spawn using the above-mentioned environment variables (1 is not required, I believe, but keeping it somewhat close to the actual number of CPUs would be smart). Alternatively, use containerization, people. Don't let Slurm pull you into its evil ways.
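
To make this concrete, here is a rough sketch of the workaround (os.sched_getaffinity is Linux-only, and the exact thread counts are a guess you should tune for your own job):

# Cap the thread pools before torch creates them: either in the job script,
#   OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 python train.py
# or in Python, before importing torch:
import os
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import torch

# Or cap the pools explicitly; sched_getaffinity reports the CPUs actually
# allocated to the process (Linux-only), which respects the Slurm allocation.
torch.set_num_threads(max(1, len(os.sched_getaffinity(0))))  # intra-op pool
torch.set_num_interop_threads(1)                             # inter-op pool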

In my experiments, this seems to resolve a couple of deadlocks I have been hitting and to considerably improve the behavior for this particular issue. There is still some delay when switching between train and validation workers, which might be a bug on Lightning's side (verification needed), but at least the training is now manageable.

This might be the same issue as #4450, or pretty much most other non-reproducible performance issues in torch/lightning repos.

amirshamaeisynex commented 1 week ago

I tracked down my problem to evaluation_loop.py in PL. The line "iter(data_fetcher)  # creates the iterator inside the fetcher" takes too long to run. I guess the data_fetcher is the culprit here.
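
A quick way to check this outside Lightning (just a sketch with a dummy loader, not the actual code): creating the iterator is where non-persistent workers get started, so it should show the same delay.

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real validation DataLoader (non-persistent workers).
val_loader = DataLoader(TensorDataset(torch.randn(1024, 8)), batch_size=256,
                        num_workers=4, persistent_workers=False)

start = time.perf_counter()
it = iter(val_loader)       # non-persistent workers are forked/spawned here
print(f"iterator creation took {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
first_batch = next(it)      # first batch after the workers warm up
print(f"first batch took {time.perf_counter() - start:.2f}s")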

amirshamaeisynex commented 1 week ago

in my case persistent_workers=True solved the issue

Borda commented 1 week ago

in my case persistent_workers=True solved the issue

Sounds good, would you give it a try?