Open kopalja opened 3 months ago
Thanks for reporting the issue. Setting precision to '32-true' fixes the problem for me.
Yes, but that is not really a solution. In addition, the problem might still be present and manifest itself at a different training step.
Agreed it's not a fix, but it saved me from having to rewrite my implementation or tell my PI that we had to wait for a bug to be fixed before we could finish our paper.
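For anyone wondering why switching to '32-true' could make the symptom disappear: one possible explanation, not confirmed anywhere in this thread, is dynamic loss scaling. Under 16-bit mixed precision, a gradient scaler such as `torch.amp.GradScaler` skips the optimizer step whenever the scaled gradients contain inf or NaN, and lowers the scale instead. The toy sketch below is pure Python and only imitates that behavior. The class `ToyLossScaler` and the gradient values are hypothetical, and the default constants are assumed to mirror the documented `GradScaler` defaults.

```python
import math


class ToyLossScaler:
    """Toy model of dynamic loss scaling; NOT the real torch.amp.GradScaler."""

    def __init__(self, scale=2.0**16, backoff=0.5, growth=2.0, growth_interval=2000):
        self.scale = scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without overflow

    def step(self, weights, scaled_grads, lr):
        """Apply a plain SGD update unless a gradient overflowed.

        Returns True if the weights were updated, False if the step was skipped.
        """
        if any(not math.isfinite(g) for g in scaled_grads):
            # Overflow: leave the weights untouched and shrink the scale.
            self.scale *= self.backoff
            self.good_steps = 0
            return False
        for i, g in enumerate(scaled_grads):
            weights[i] -= lr * g / self.scale  # unscale before updating
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # Long run without overflow: try a larger scale again.
            self.scale *= self.growth
            self.good_steps = 0
        return True


scaler = ToyLossScaler()
weights = [1.0, -2.0]

# Hypothetical scaled gradients for three steps; the second one has overflowed.
grads_per_step = [
    [2.0**16 * 0.5, 2.0**16 * -0.25],
    [float("inf"), 3.0],
    [2.0**15 * 0.5, 2.0**15 * -0.25],
]
applied = [scaler.step(weights, g, lr=0.1) for g in grads_per_step]
print("step applied:", applied)
print("final scale:", scaler.scale)
```

If this is what happens in the reported case, the skipped step would be expected behavior of the scaler rather than a lost update, and it would move to a different training step when the batch size or learning rate changes, since those change when the first overflow occurs.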
Looks like it affects Lightning version 2.3.3.
Bug description
Hi, I am using PyTorch Lightning to implement some new optimization strategies using `automatic_optimization=False`. For certain settings, my optimization strategy (using `automatic_optimization=False`) should yield the same results as the standard optimization process (`automatic_optimization=True`). However, I could not make it work: my optimization process was returning slightly different results from the default optimization process. After a while I figured out that PyTorch Lightning sometimes does not update the model weights when using the default `automatic_optimization=True`. I have put together a minimal example in which the model weights don't get updated on step 5. The model weights also fail to update when using different hyper-parameters (e.g., batch size, lr), only at a different training step.

Am I missing something, or does this look like a bug? Thanks!
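One way to check the claim that a given step leaves the weights unchanged is to snapshot the parameters before the step and compare afterwards. The sketch below is not the reporter's reproduction script (which is not shown here). It is a plain PyTorch loop with a hypothetical toy model, and the helper names `snapshot` and `changed_parameters` are made up for illustration.

```python
import torch
from torch import nn


def snapshot(model: nn.Module) -> dict[str, torch.Tensor]:
    """Copy every parameter so later in-place updates do not alter the snapshot."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def changed_parameters(model: nn.Module, before: dict[str, torch.Tensor]) -> list[str]:
    """Return the names of parameters whose values differ from the snapshot."""
    return [
        name
        for name, param in model.named_parameters()
        if not torch.equal(param.detach(), before[name])
    ]


torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

skipped_steps = []
for step in range(3):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    before = snapshot(model)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # An empty list means no parameter changed during this step.
    if not changed_parameters(model, before):
        skipped_steps.append(step)

print("steps without a weight update:", skipped_steps)
```

In a Lightning run, the same two helpers could presumably be called from the `on_train_batch_start` and `on_train_batch_end` hooks to record which training steps leave the weights untouched.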
What version are you seeing the problem on?
v2.4
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
* CUDA:
    - GPU:
        - NVIDIA A100-PCIE-40GB
    - available: True
    - version: 12.1
* Lightning:
    - lightning-utilities: 0.11.6
    - pytorch-lightning: 2.3.3
    - torch: 2.4.0
    - torchmetrics: 1.4.1
    - torchvision: 0.19.0
* Packages:
    - absl-py: 2.1.0
    - aiohappyeyeballs: 2.3.4
    - aiohttp: 3.10.1
    - aiosignal: 1.3.1
    - asttokens: 2.4.1
    - attrs: 24.1.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - black: 24.8.0
    - certifi: 2024.7.4
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - comm: 0.2.2
    - datasets: 2.20.0
    - debugpy: 1.8.5
    - decorator: 5.1.1
    - dill: 0.3.8
    - exceptiongroup: 1.2.2
    - executing: 2.0.1
    - filelock: 3.15.4
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - gdown: 5.2.0
    - grpcio: 1.65.4
    - huggingface-hub: 0.24.5
    - idna: 3.7
    - importlib-metadata: 8.2.0
    - importlib-resources: 6.4.0
    - inflect: 7.3.1
    - ipykernel: 6.29.5
    - ipython: 8.26.0
    - isort: 5.13.2
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - lightning-utilities: 0.11.6
    - markdown: 3.6
    - markupsafe: 2.1.5
    - matplotlib-inline: 0.1.7
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - multiprocess: 0.70.16
    - mypy-extensions: 1.0.0
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - numpy: 2.0.1
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.6.20
    - nvidia-nvtx-cu12: 12.1.105
    - ordered-set: 4.1.0
    - packaging: 24.1
    - pandas: 2.2.2
    - parso: 0.8.4
    - pathspec: 0.12.1
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.4.0
    - pip: 24.2
    - platformdirs: 4.2.2
    - prompt-toolkit: 3.0.47
    - protobuf: 4.25.4
    - psutil: 6.0.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - pyarrow: 17.0.0
    - pyarrow-hotfix: 0.6
    - pygments: 2.18.0
    - pynvml: 11.5.3
    - pysocks: 1.7.1
    - python-dateutil: 2.9.0
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.1.0
    - regex: 2024.7.24
    - requests: 2.32.3
    - safetensors: 0.4.4
    - setuptools: 72.1.0
    - six: 1.16.0
    - soupsieve: 2.5
    - stack-data: 0.6.2
    - sympy: 1.13.1
    - tensorboard: 2.17.0
    - tensorboard-data-server: 0.7.2
    - tiktoken: 0.7.0
    - tokenizers: 0.19.1
    - tomli: 2.0.1
    - torch: 2.4.0
    - torchmetrics: 1.4.1
    - torchvision: 0.19.0
    - tornado: 6.4.1
    - tqdm: 4.66.5
    - traitlets: 5.14.3
    - transformers: 4.44.0
    - triton: 3.0.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2024.1
    - urllib3: 2.2.2
    - wcwidth: 0.2.13
    - werkzeug: 3.0.3
    - wheel: 0.44.0
    - xxhash: 3.4.1
    - yarl: 1.9.4
    - zipp: 3.19.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.12.4
    - release: 3.10.0-1160.71.1.el7.x86_64
    - version: #1 SMP Tue Jun 28 15:37:28 UTC 2022
```
More info
No response