It turns out that `[Strategy]SingleDeviceStrategy.backward` is the bottleneck according to `Trainer(profiler='simple')`.
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage (%) |
|---|---|---|---|---|
| Total | - | 391 | 36.704 | 100 |
| run_training_epoch | 18.419 | 1 | 18.419 | 50.182 |
| run_training_batch | 1.8746 | 5 | 9.3728 | 25.536 |
| [LightningModule]GeometryTransformer.optimizer_step | 1.8704 | 5 | 9.3518 | 25.479 |
| [TrainingEpochLoop].train_dataloader_next | 1.1561 | 5 | 5.7803 | 15.748 |
| [Strategy]SingleDeviceStrategy.backward | 1.0126 | 5 | 5.0628 | 13.794 |
| [EvaluationEpochLoop].None_dataloader_idx_0_next | 2.444 | 2 | 4.888 | 13.317 |
| [Strategy]SingleDeviceStrategy.training_step | 0.72914 | 5 | 3.6457 | 9.9326 |
| [Strategy]SingleDeviceStrategy.batch_to_device | 0.35726 | 7 | 2.5008 | 6.8134 |


| Action | Mean duration (s) | Num calls | Total time (s) | Percentage (%) |
|---|---|---|---|---|
| Total | - | 391 | 143.8 | 100 |
| run_training_epoch | 125.26 | 1 | 125.26 | 87.105 |
| run_training_batch | 23.238 | 5 | 116.19 | 80.798 |
| [LightningModule]GeometryTransformer.optimizer_step | 23.233 | 5 | 116.16 | 80.78 |
| [Strategy]SingleDeviceStrategy.backward | 22.616 | 5 | 113.08 | 78.636 |
| [TrainingEpochLoop].train_dataloader_next | 1.187 | 5 | 5.9348 | 4.127 |
| [EvaluationEpochLoop].None_dataloader_idx_0_next | 2.7457 | 2 | 5.4915 | 3.8187 |
| [Strategy]SingleDeviceStrategy.training_step | 0.5668 | 5 | 2.834 | 1.9707 |
| [Strategy]SingleDeviceStrategy.batch_to_device | 0.37026 | 7 | 2.5918 | 1.8023 |

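
To make the two profiler tables above easier to compare, here is a small stdlib-only sketch that parses a subset of their rows and reports, per action, how many seconds were added and the slowdown factor. It assumes the first table is the full-precision run and the second is the `precision=16` run; the tables themselves are not labeled. The profiler rows are nested (with automatic optimization, for instance, `backward` runs inside `optimizer_step`), so they should be compared row by row rather than summed.

```python
from typing import Dict, List, Tuple

# Subset of rows copied from the two simple-profiler tables above.
# ASSUMPTION: RUN_A is the full-precision run, RUN_B the precision=16 run.
RUN_A = """
| Total | - | 391 | 36.704 | 100 % |
| run_training_batch | 1.8746 | 5 | 9.3728 | 25.536 |
| [LightningModule]GeometryTransformer.optimizer_step | 1.8704 | 5 | 9.3518 | 25.479 |
| [TrainingEpochLoop].train_dataloader_next | 1.1561 | 5 | 5.7803 | 15.748 |
| [Strategy]SingleDeviceStrategy.backward | 1.0126 | 5 | 5.0628 | 13.794 |
| [Strategy]SingleDeviceStrategy.training_step | 0.72914 | 5 | 3.6457 | 9.9326 |
| [Strategy]SingleDeviceStrategy.batch_to_device | 0.35726 | 7 | 2.5008 | 6.8134 |
"""

RUN_B = """
| Total | - | 391 | 143.8 | 100 % |
| run_training_batch | 23.238 | 5 | 116.19 | 80.798 |
| [LightningModule]GeometryTransformer.optimizer_step | 23.233 | 5 | 116.16 | 80.78 |
| [TrainingEpochLoop].train_dataloader_next | 1.187 | 5 | 5.9348 | 4.127 |
| [Strategy]SingleDeviceStrategy.backward | 22.616 | 5 | 113.08 | 78.636 |
| [Strategy]SingleDeviceStrategy.training_step | 0.5668 | 5 | 2.834 | 1.9707 |
| [Strategy]SingleDeviceStrategy.batch_to_device | 0.37026 | 7 | 2.5918 | 1.8023 |
"""


def parse_totals(table: str) -> Dict[str, float]:
    """Map each profiled action to its 'Total time (s)' column."""
    totals: Dict[str, float] = {}
    for line in table.strip().splitlines():
        # "| action | mean | calls | total | pct |" -> five stripped cells
        cells = [c.strip() for c in line.strip().strip("|").split("|")]
        totals[cells[0]] = float(cells[3])
    return totals


def compare(before: Dict[str, float],
            after: Dict[str, float]) -> List[Tuple[str, float, float]]:
    """Return (action, seconds added, slowdown factor), largest increase first."""
    rows = [(a, after[a] - before[a], after[a] / before[a])
            for a in before if a in after]
    return sorted(rows, key=lambda r: r[1], reverse=True)


if __name__ == "__main__":
    for action, added, factor in compare(parse_totals(RUN_A), parse_totals(RUN_B)):
        print(f"{action:55s} {added:+9.2f} s  x{factor:5.2f}")
```

With these numbers, the `backward` row grows by about 108 s (roughly 22x), slightly more than the roughly 107 s added to the `Total` row (about 3.9x), while `training_step` actually gets a little faster. In other words, essentially all of the extra time in these runs is attributed to the backward call.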
What might the reasons be? 🥲
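
One possible pitfall when timing individual components on a GPU, offered as an assumption to check rather than a diagnosis: CUDA kernels are launched asynchronously, so a host-side timer around a call such as `training_step()` can stop before the GPU has finished that call's work, and the remaining time is then charged to whichever later call synchronizes. The sketch below is a hypothetical helper (`timed` is not a Lightning or PyTorch API) that optionally runs a synchronization function before and after the measured block; on a GPU you would pass `torch.cuda.synchronize`. The names `model`, `batch` and `batch_idx` in the commented GPU example are placeholders.

```python
import time
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional


@contextmanager
def timed(name: str, results: Dict[str, float],
          sync: Optional[Callable[[], None]] = None) -> Iterator[None]:
    """Add the wall-clock duration of the enclosed block to results[name].

    If `sync` is given (e.g. torch.cuda.synchronize), it is called before
    starting and before stopping the clock, so queued asynchronous work is
    counted in the block that launched it.
    """
    if sync is not None:
        sync()  # drain work queued before the block
    start = time.perf_counter()
    try:
        yield
    finally:
        if sync is not None:
            sync()  # wait for work launched inside the block
        results[name] = results.get(name, 0.0) + time.perf_counter() - start


if __name__ == "__main__":
    timings: Dict[str, float] = {}
    with timed("sleep", timings):
        time.sleep(0.01)
    print(timings)
    # On a GPU (assumes torch is installed; model/batch/batch_idx are placeholders):
    # with timed("training_step", timings, sync=torch.cuda.synchronize):
    #     loss = model.training_step(batch, batch_idx)
```

Synchronizing makes each measurement slower, but it keeps GPU time from drifting into the next blocking call. The same caveat may apply when reading host-side profiler rows.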
Bug description
Initially I trained the model in full precision, and everything seemed to work fine.
However, after setting `precision=16` in `pl.Trainer()`, training became 10x slower. After timing each component, I found no significant change in the duration of `pl.LightningModule.training_step()`. Thus, I suspect the problem may lie in the Lightning framework itself. What might the causes be, and how can I solve them?
How to reproduce the bug
Since the duration of `training_step()` doesn't change, I have only attached the code of the LightningModule, the training script, and the related configs as follows:
Error messages and logs
No response
Environment
Current environment
```
#- Lightning Component: Trainer, LightningModule, torchmetrics
#- PyTorch Lightning Version: 1.9.0
#- PyTorch Version: 1.10.2+cu111
#- Python version: 3.9.15
#- OS: Linux
#- CUDA/cuDNN version: 11.1
#- How you installed Lightning: pip

* CUDA:
  - GPU:
    - NVIDIA GeForce RTX 3090
  - available: True
  - version: 11.1
* Lightning:
  - lightning: 1.9.0
  - lightning-cloud: 0.5.19
  - lightning-utilities: 0.6.0.post0
  - torch: 1.10.2+cu111
  - torchmetrics: 0.11.0
  - torchvision: 0.11.3+cu111
* Packages:
  - aiohttp: 3.8.3
  - aiosignal: 1.3.1
  - antlr4-python3-runtime: 4.9.3
  - anyio: 3.6.2
  - arrow: 1.2.3
  - async-timeout: 4.0.2
  - attrs: 22.2.0
  - beautifulsoup4: 4.11.1
  - blessed: 1.19.1
  - certifi: 2022.9.24
  - charset-normalizer: 2.1.1
  - click: 8.1.3
  - contourpy: 1.0.7
  - croniter: 1.3.8
  - cycler: 0.11.0
  - dateutils: 0.6.12
  - deepdiff: 6.2.3
  - dnspython: 2.3.0
  - docopt: 0.6.2
  - einops: 0.6.0
  - email-validator: 1.3.1
  - fastapi: 0.88.0
  - filelock: 3.9.0
  - flake8: 6.0.0
  - fonttools: 4.38.0
  - frozenlist: 1.3.3
  - fsspec: 2023.1.0
  - fvcore: 0.1.5.post20221221
  - h11: 0.14.0
  - httpcore: 0.16.3
  - httptools: 0.5.0
  - httpx: 0.23.3
  - huggingface-hub: 0.12.0
  - hydra-core: 1.3.1
  - idna: 3.4
  - imageio: 2.25.0
  - importlib-resources: 5.12.0
  - inquirer: 3.1.2
  - iopath: 0.1.10
  - itsdangerous: 2.1.2
  - jinja2: 3.1.2
  - kiwisolver: 1.4.4
  - lightning: 1.9.0
  - lightning-cloud: 0.5.19
  - lightning-utilities: 0.6.0.post0
  - llvmlite: 0.39.1
  - markdown-it-py: 2.1.0
  - markupsafe: 2.1.2
  - matplotlib: 3.7.0
  - mccabe: 0.7.0
  - mdurl: 0.1.2
  - multidict: 6.0.4
  - networkx: 3.0
  - numba: 0.56.4
  - numpy: 1.23.4
  - omegaconf: 2.3.0
  - ordered-set: 4.1.0
  - orjson: 3.8.5
  - packaging: 22.0
  - pillow: 9.3.0
  - pip: 22.2.2
  - pipreqs: 0.4.11
  - portalocker: 2.7.0
  - psutil: 5.9.4
  - pycodestyle: 2.10.0
  - pydantic: 1.10.4
  - pyflakes: 3.0.1
  - pygments: 2.14.0
  - pyjwt: 2.6.0
  - pyparsing: 3.0.9
  - python-dateutil: 2.8.2
  - python-dotenv: 0.21.1
  - python-editor: 1.0.4
  - python-multipart: 0.0.5
  - pytz: 2022.7.1
  - pywavelets: 1.4.1
  - pyyaml: 6.0
  - readchar: 4.0.3
  - requests: 2.28.2
  - rfc3986: 1.5.0
  - rich: 13.2.0
  - scikit-image: 0.19.3
  - scipy: 1.10.0
  - setuptools: 65.5.0
  - six: 1.16.0
  - sniffio: 1.3.0
  - soupsieve: 2.3.2.post1
  - starlette: 0.22.0
  - starsessions: 1.3.0
  - tabulate: 0.9.0
  - termcolor: 2.2.0
  - tifffile: 2023.1.23.1
  - timm: 0.6.12
  - torch: 1.10.2+cu111
  - torchmetrics: 0.11.0
  - torchvision: 0.11.3+cu111
  - tqdm: 4.64.1
  - traitlets: 5.8.1
  - typing-extensions: 4.4.0
  - ujson: 5.7.0
  - urllib3: 1.26.14
  - uvicorn: 0.20.0
  - uvloop: 0.17.0
  - watchfiles: 0.18.1
  - wcwidth: 0.2.6
  - websocket-client: 1.4.2
  - websockets: 10.4
  - wheel: 0.37.1
  - yacs: 0.1.8
  - yapf: 0.32.0
  - yarg: 0.1.9
  - yarl: 1.8.2
  - zipp: 3.14.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.9.15
  - version: #1 SMP Thu Nov 8 23:39:32 UTC 2018
```
More info
No response
cc @tchaton @borda