Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.27k stars 3.38k forks source link

fp16 is almost 10x slower than fp32 in my case? #16896

Closed npurson closed 1 year ago

npurson commented 1 year ago

Bug description

Initially I trained the model with full precision, and everything seemed to work fine.

Epoch 0:   0%|▌                               | 21/4649 [00:55<3:24:42,  2.65s/it, loss=42.4, v_num=14]

However, after setting precision=16 in pl.Trainer(), training becomes about 10x slower.

Epoch 0:   0%|                                | 2/4649 [00:53<34:40:02, 26.86s/it, loss=45.7, v_num=15]

After timing each component, I found no significant change in the duration of pl.LightningModule.training_step(). So I suspect the problem may lie in the Lightning framework itself. What might the cause be, and how can I solve it?

How to reproduce the bug

Since the duration of training_step() doesn't change, I only attach the LightningModule code, the training script, and the related config below:

import torch.optim as optim
import lightning.pytorch as pl
from torch.cuda.amp import autocast

from ... import evaluation

class PLModelInterface(pl.LightningModule):
    """Config-driven LightningModule base class.

    Provides shared training/validation/test steps, metric logging and optimizer /
    scheduler construction; subclasses define the network, ``forward`` and ``losses``.

    Args:
        optimizer: Optimizer config. ``type`` names a class in ``torch.optim``, ``cfgs``
            holds its keyword arguments, and the optional ``paramwise_cfg`` is a list of
            extra parameter-group configs (each with an optional ``name`` substring matched
            against parameter names and an optional ``lr_mult`` applied to ``cfgs.lr``).
        scheduler: Scheduler config. ``type`` names a class in ``torch.optim.lr_scheduler``,
            ``cfgs`` holds its keyword arguments, and the optional ``interval`` is returned
            to Lightning together with the scheduler.
        evaluator: Evaluator config. ``type`` names a class in the project ``evaluation``
            module and ``cfgs`` holds its keyword arguments. One instance is created for
            training and a separate one for validation/test.
        **kwargs: Only ``class_names`` is used here (labels for per-class IoU logging).
    """

    def __init__(self, optimizer, scheduler, evaluator, **kwargs):
        super().__init__()
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler
        self.train_evaluator = getattr(evaluation, evaluator.type)(**evaluator.cfgs)
        self.test_evaluator = getattr(evaluation, evaluator.type)(**evaluator.cfgs)
        if 'class_names' in kwargs:
            self.class_names = kwargs['class_names']
        ...  # define your model afterward

    def forward(self, x):
        ...

    def losses(self, pred, y):
        ...

    @staticmethod
    def _total_loss(loss):
        """Sum a dict of loss terms into one value; return a non-dict loss unchanged."""
        return sum(loss.values()) if isinstance(loss, dict) else loss

    def _step(self, batch, evaluator=None):
        """Run forward and loss on an ``(x, y)`` batch, optionally updating ``evaluator``.

        Returns:
            Whatever ``self.losses`` returns (handled elsewhere as either a single
            value or a dict of named loss terms).
        """
        x, y = batch
        pred = self(x)
        # The loss is computed with autocast disabled.
        # NOTE(review): disabling autocast does not upcast tensors that were already
        # produced in fp16 inside the autocast region; the torch.cuda.amp docs recommend
        # an explicit .float() on them if an fp32 loss is intended — confirm for `losses`.
        with autocast(enabled=False):
            loss = self.losses(pred, y)
        if evaluator is not None:
            evaluator.update(pred, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it, and return the scalar to optimize."""
        loss = self._step(batch, self.train_evaluator)
        total = self._total_loss(loss)
        if isinstance(loss, dict):
            # Log the total together with every individual term under one name.
            self.log('train_loss', {'loss_total': total, **loss})
        else:
            self.log('train_loss', total)
        return total

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, 'test')

    def _shared_eval(self, batch, prefix):
        """Evaluation step shared by validation and test; logs ``<prefix>_loss``."""
        # NOTE(review): test batches also update `test_evaluator`, but no test-epoch hook
        # here computes/resets it — only `validation_epoch_end` does.
        loss = self._step(batch, self.test_evaluator)
        # Lightning automatically accumulates the metric and averages it
        # if `self.log` is inside the `validation_step` and `test_step`
        self.log(f'{prefix}_loss', loss, sync_dist=True)

    # `training_epoch_end` / `validation_epoch_end` are Lightning 1.x hooks
    # (replaced by `on_*_epoch_end` in Lightning 2.0).
    def training_epoch_end(self, outputs):
        self._log_metrics(self.train_evaluator, 'train')

    def validation_epoch_end(self, outputs):
        self._log_metrics(self.test_evaluator, 'val')

    def _log_metrics(self, evaluator, prefix=None):
        """Log the evaluator's aggregated metrics (optionally prefixed), then reset it.

        ``evaluator.compute()`` must return a dict containing ``'iou_per_class'``; that
        entry is logged separately, keyed by ``self.class_names`` when available.
        """
        metrics = evaluator.compute()
        iou_per_class = metrics.pop('iou_per_class')
        if prefix:
            metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
        self.log_dict(metrics, sync_dist=True)

        if hasattr(self, 'class_names'):
            # `prefix` defaults to None, so only prepend it when one is given.
            name = f'{prefix}_iou_per_cls' if prefix else 'iou_per_cls'
            self.log(name,
                     {c: s.item() for c, s in zip(self.class_names, iou_per_class)},
                     sync_dist=True)
        evaluator.reset()

    def _split_params(self, paramwise_cfg):
        """Partition parameters into a default list and one list per ``paramwise_cfg`` entry.

        A parameter goes to the FIRST group whose ``name`` is a substring of the
        parameter's name, otherwise to the default list. Every parameter is assigned
        exactly once, because torch optimizers reject a parameter that appears in more
        than one parameter group.
        """
        default_params = []
        group_params = [[] for _ in paramwise_cfg]
        for param_name, param in self.named_parameters():
            for i, pg_cfg in enumerate(paramwise_cfg):
                if 'name' in pg_cfg and pg_cfg.name in param_name:
                    group_params[i].append(param)
                    break
                # USER: Customize more cfgs if needed
            else:
                default_params.append(param)
        return default_params, group_params

    def configure_optimizers(self):
        """Build the optimizer (with optional per-group overrides) and LR scheduler."""
        optimizer_cfg = self.optimizer_cfg
        scheduler_cfg = self.scheduler_cfg
        if 'paramwise_cfg' in optimizer_cfg:
            paramwise_cfg = optimizer_cfg.paramwise_cfg
            params, group_params = self._split_params(paramwise_cfg)
        else:
            paramwise_cfg, params, group_params = [], self.parameters(), []
        optimizer = getattr(optim, optimizer_cfg.type)(params, **optimizer_cfg.cfgs)
        for pg, pg_cfg in zip(group_params, paramwise_cfg):
            cfg = {}
            if 'lr_mult' in pg_cfg:
                cfg['lr'] = optimizer_cfg.cfgs.lr * pg_cfg.lr_mult
            # USER: Customize more cfgs if needed
            optimizer.add_param_group({'params': pg, **cfg})
        scheduler = getattr(optim.lr_scheduler, scheduler_cfg.type)(optimizer, **scheduler_cfg.cfgs)
        if 'interval' in scheduler_cfg:
            scheduler = {'scheduler': scheduler, 'interval': scheduler_cfg.interval}
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
import sys

sys.path.append('.')  # run from project root

import os
import hydra
import lightning.pytorch as pl
from omegaconf import DictConfig, OmegaConf

from ssc3d import build_data_loaders, build_from_configs, models

@hydra.main(version_base=None, config_path='../configs', config_name='config')
def main(cfg: DictConfig):
    """Build data loaders, model and trainer from the Hydra config, then train.

    Args:
        cfg: Composed Hydra config; ``cfg.data``, ``cfg.model``, ``cfg.solver`` and
            ``cfg.trainer`` are read (after ``build_from_configs`` post-processes it).
    """
    # Environment variables are strings: convert before comparing so the check also
    # holds when a launcher (e.g. torchrun) sets LOCAL_RANK='0'.
    if int(os.environ.get('LOCAL_RANK', 0)) == 0:
        print(OmegaConf.to_yaml(cfg))
    # presumably `callbacks` is a mapping of extra Trainer kwargs (it is splatted into
    # pl.Trainer below) — confirm against build_from_configs.
    cfg, callbacks = build_from_configs(cfg)

    dls, meta_info = build_data_loaders(cfg.data)  # dls is a list of dataloaders, I didn't use pl.DataModule
    model = getattr(models, cfg.model.type)(**cfg.model.cfgs, **cfg.solver, **meta_info)
    trainer = pl.Trainer(**cfg.trainer, **callbacks)
    trainer.fit(model, *dls)  # resume training by `ckpt_path=`


if __name__ == '__main__':
    main()
trainer:
  devices: 4            # number of GPUs used by the run
  accelerator: 'gpu'
  # ddp_find_unused_parameters_false is experimental and subject to change
  strategy: 'ddp'
  sync_batchnorm: true  # synchronize BatchNorm statistics across DDP processes
  precision: 16         # 16-bit mixed precision (32 = full precision)

Error messages and logs

No response

Environment

Current environment ``` #- Lightning Component: Trainer, LightningModule, torchmetrics #- PyTorch Lightning Version: 1.9.0 #- PyTorch Version: 1.10.2+cu111 #- Python version: 3.9.15 #- OS: Linux #- CUDA/cuDNN version: 11.1 #- How you installed Lightning: pip * CUDA: - GPU: - NVIDIA GeForce RTX 3090 - available: True - version: 11.1 * Lightning: - lightning: 1.9.0 - lightning-cloud: 0.5.19 - lightning-utilities: 0.6.0.post0 - torch: 1.10.2+cu111 - torchmetrics: 0.11.0 - torchvision: 0.11.3+cu111 * Packages: - aiohttp: 3.8.3 - aiosignal: 1.3.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 22.2.0 - beautifulsoup4: 4.11.1 - blessed: 1.19.1 - certifi: 2022.9.24 - charset-normalizer: 2.1.1 - click: 8.1.3 - contourpy: 1.0.7 - croniter: 1.3.8 - cycler: 0.11.0 - dateutils: 0.6.12 - deepdiff: 6.2.3 - dnspython: 2.3.0 - docopt: 0.6.2 - einops: 0.6.0 - email-validator: 1.3.1 - fastapi: 0.88.0 - filelock: 3.9.0 - flake8: 6.0.0 - fonttools: 4.38.0 - frozenlist: 1.3.3 - fsspec: 2023.1.0 - fvcore: 0.1.5.post20221221 - h11: 0.14.0 - httpcore: 0.16.3 - httptools: 0.5.0 - httpx: 0.23.3 - huggingface-hub: 0.12.0 - hydra-core: 1.3.1 - idna: 3.4 - imageio: 2.25.0 - importlib-resources: 5.12.0 - inquirer: 3.1.2 - iopath: 0.1.10 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - kiwisolver: 1.4.4 - lightning: 1.9.0 - lightning-cloud: 0.5.19 - lightning-utilities: 0.6.0.post0 - llvmlite: 0.39.1 - markdown-it-py: 2.1.0 - markupsafe: 2.1.2 - matplotlib: 3.7.0 - mccabe: 0.7.0 - mdurl: 0.1.2 - multidict: 6.0.4 - networkx: 3.0 - numba: 0.56.4 - numpy: 1.23.4 - omegaconf: 2.3.0 - ordered-set: 4.1.0 - orjson: 3.8.5 - packaging: 22.0 - pillow: 9.3.0 - pip: 22.2.2 - pipreqs: 0.4.11 - portalocker: 2.7.0 - psutil: 5.9.4 - pycodestyle: 2.10.0 - pydantic: 1.10.4 - pyflakes: 3.0.1 - pygments: 2.14.0 - pyjwt: 2.6.0 - pyparsing: 3.0.9 - python-dateutil: 2.8.2 - python-dotenv: 0.21.1 - python-editor: 1.0.4 - python-multipart: 0.0.5 - pytz: 2022.7.1 - pywavelets: 1.4.1 - 
pyyaml: 6.0 - readchar: 4.0.3 - requests: 2.28.2 - rfc3986: 1.5.0 - rich: 13.2.0 - scikit-image: 0.19.3 - scipy: 1.10.0 - setuptools: 65.5.0 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.3.2.post1 - starlette: 0.22.0 - starsessions: 1.3.0 - tabulate: 0.9.0 - termcolor: 2.2.0 - tifffile: 2023.1.23.1 - timm: 0.6.12 - torch: 1.10.2+cu111 - torchmetrics: 0.11.0 - torchvision: 0.11.3+cu111 - tqdm: 4.64.1 - traitlets: 5.8.1 - typing-extensions: 4.4.0 - ujson: 5.7.0 - urllib3: 1.26.14 - uvicorn: 0.20.0 - uvloop: 0.17.0 - watchfiles: 0.18.1 - wcwidth: 0.2.6 - websocket-client: 1.4.2 - websockets: 10.4 - wheel: 0.37.1 - yacs: 0.1.8 - yapf: 0.32.0 - yarg: 0.1.9 - yarl: 1.8.2 - zipp: 3.14.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.15 - version: #1 SMP Thu Nov 8 23:39:32 UTC 2018 ```

More info

No response

cc @tchaton @borda

npurson commented 1 year ago

According to Trainer(profiler='simple'), the bottleneck turns out to be [Strategy]SingleDeviceStrategy.backward.

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                       |  Mean duration (s)|  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                        |  -               |  391                   |  36.704               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                                                                                                           |  18.419          |  1                     |  18.419               |  50.182               |
|  run_training_batch                                                                                                                                                           |  1.8746          |  5                     |  9.3728               |  25.536               |
|  [LightningModule]GeometryTransformer.optimizer_step                                                                                                                          |  1.8704          |  5                     |  9.3518               |  25.479               |
|  [TrainingEpochLoop].train_dataloader_next                                                                                                                                    |  1.1561          |  5                     |  5.7803               |  15.748               |
|  [Strategy]SingleDeviceStrategy.backward                                                                                                                                      |  1.0126          |  5                     |  5.0628               |  13.794               |
|  [EvaluationEpochLoop].None_dataloader_idx_0_next                                                                                                                             |  2.444           |  2                     |  4.888                |  13.317               |
|  [Strategy]SingleDeviceStrategy.training_step                                                                                                                                 |  0.72914         |  5                     |  3.6457               |  9.9326               |
|  [Strategy]SingleDeviceStrategy.batch_to_device                                                                                                                               |  0.35726         |  7                     |  2.5008               |  6.8134               |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                       |  Mean duration (s)|  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                        |  -               |  391                   |  143.8                |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                                                                                                           |  125.26          |  1                     |  125.26               |  87.105               |
|  run_training_batch                                                                                                                                                           |  23.238          |  5                     |  116.19               |  80.798               |
|  [LightningModule]GeometryTransformer.optimizer_step                                                                                                                          |  23.233          |  5                     |  116.16               |  80.78                |
|  [Strategy]SingleDeviceStrategy.backward                                                                                                                                      |  22.616          |  5                     |  113.08               |  78.636               |
|  [TrainingEpochLoop].train_dataloader_next                                                                                                                                    |  1.187           |  5                     |  5.9348               |  4.127                |
|  [EvaluationEpochLoop].None_dataloader_idx_0_next                                                                                                                             |  2.7457          |  2                     |  5.4915               |  3.8187               |
|  [Strategy]SingleDeviceStrategy.training_step                                                                                                                                 |  0.5668          |  5                     |  2.834                |  1.9707               |
|  [Strategy]SingleDeviceStrategy.batch_to_device                                                                                                                               |  0.37026         |  7                     |  2.5918               |  1.8023               |

What might the reasons be? 🥲