Model does not update its weights #20215

kopalja commented 3 weeks ago

Bug description

Hi, I am using PyTorch Lightning to implement some new optimization strategies with automatic_optimization=False. For certain settings, my optimization strategy (automatic_optimization=False) should yield the same results as the standard optimization process (automatic_optimization=True). However, I could not make it work: my optimization process returned slightly different results from the default one. After a while, I figured out that PyTorch Lightning sometimes does not update the model weights when using the default automatic_optimization=True. I have put together a minimal example in which the model weights are not updated at step 5. With different hyperparameters (e.g., batch size, lr), the weights also fail to update, just at a different training step.

Am I missing something, or does this look like a bug? Thanks!

How to reproduce the bug

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(1, 64, 3, 1),
                nn.Conv2d(64, 64, 3, 1),
                nn.Conv2d(64, 128, 3, 1),
            ]
        )
        self.fc1 = nn.Linear(128, 10)

    def forward(self, x, target):
        for conv in self.convs:
            x = conv(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        logits = F.log_softmax(x, dim=1)
        return F.nll_loss(logits, target)

class MRELoop(pl.LightningModule):
    def __init__(self):
        super(MRELoop, self).__init__()
        self.model = CNN()
        self.dataset = datasets.MNIST(root=".mnist_data", download=True, transform=transforms.ToTensor())
        self.previous_params = None

    def training_step(self, batch, batch_idx):
        # Check whether the new model weights differ from the previous ones
        # (torch.cat copies, so the snapshot is not changed by later updates)
        params = torch.cat([param.detach().view(-1) for param in self.model.parameters()])
        if self.previous_params is not None:
            num_different_values = (self.previous_params != params).sum().item()
            # Stop training as soon as a step leaves every weight unchanged
            self.trainer.should_stop = num_different_values == 0
        else:
            num_different_values = None

        self.previous_params = params
        loss = self.model(*batch)
        print(
            f"step {batch_idx} | diff weights: {num_different_values} | all weights: {params.numel()} | weights mean: {torch.mean(params).item()} | loss: {loss.item()}"
        )
        return loss

    def configure_optimizers(self):
        # The bug also occurs with a different lr, just at a different training step
        return torch.optim.AdamW(self.parameters(), lr=2e-3)
        # return torch.optim.SGD(self.parameters(), lr=9e-4)  # Also happens with SGD

    def train_dataloader(self):
        return DataLoader(self.dataset)

if __name__ == "__main__":
    pl_trainer = pl.Trainer(
        precision="16-mixed",  # So far the bug has only occurred with 16-mixed
    )
    pl_trainer.fit(MRELoop())
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/lightning_fabric/plugins/environments/ The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python ...
Using 16bit Automatic Mixed Precision (AMP)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/ `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/loops/ `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name  | Type | Params | Mode 
0 | model | CNN  | 112 K  | train
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.451     Total estimated model params size (MB)
/home/kopal/miniconda3/envs/overshoot/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/ The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
step 0 | diff weights: None | all weights: 112714 | weights mean: 1.6999114450300112e-05 | loss: 2.334902763366699
step 1 | diff weights: 112714 | all weights: 112714 | weights mean: 3.690078665385954e-05 | loss: 2.32588529586792
step 2 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00010425636719446629 | loss: 2.621901512145996
step 3 | diff weights: 112714 | all weights: 112714 | weights mean: -0.00030326732667163014 | loss: 2.4029626846313477
step 4 | diff weights: 112714 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.657553195953369
step 5 | diff weights: 0 | all weights: 112714 | weights mean: -0.0005236949073150754 | loss: 2.5822641849517822
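
One plausible explanation, which nobody in this thread has confirmed, is dynamic loss scaling. With `precision="16-mixed"` Lightning uses PyTorch's `GradScaler` (see the `GradScaler` deprecation warning in the log), and `GradScaler` deliberately skips `optimizer.step()` when the scaled gradients contain inf or NaN, then lowers the scale. The weights stay unchanged for that step, which would match step 5 above. The pure-Python sketch below mimics that rule using `GradScaler`'s default parameters (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). The function name `simulate_loss_scaling` is hypothetical, and it approximates an fp16 overflow as `|grad| * scale > 65504`, a simplification of what actually happens on the GPU.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


def simulate_loss_scaling(grad_magnitudes, init_scale=65536.0, growth_factor=2.0,
                          backoff_factor=0.5, growth_interval=2000):
    """Return (applied, scales): whether each optimizer step ran, and the scale used.

    Simplification (assumption): a step "overflows" when |grad| * scale leaves the
    float16 range, standing in for inf/NaN appearing in the real fp16 gradients.
    """
    scale = init_scale
    good_steps = 0
    applied, scales = [], []
    for g in grad_magnitudes:
        scales.append(scale)
        scaled = abs(g) * scale
        if math.isfinite(scaled) and scaled <= FP16_MAX:
            applied.append(True)  # optimizer.step() would run
            good_steps += 1
            if good_steps == growth_interval:
                scale *= growth_factor  # grow after enough clean steps
                good_steps = 0
        else:
            applied.append(False)  # step skipped: weights stay unchanged
            scale *= backoff_factor  # back off the scale for the next step
            good_steps = 0
    return applied, scales


if __name__ == "__main__":
    applied, scales = simulate_loss_scaling([0.1, 0.2, 0.3, 0.4, 0.5, 2.0, 0.5])
    for step, (ok, s) in enumerate(zip(applied, scales)):
        print(f"step {step} | scale {s} | optimizer step applied: {ok}")
```

Because the gradient magnitudes depend on the data, the learning rate and the batch size, a rule like this would also explain why the skipped step moves when the hyperparameters change.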


Current environment

```
* CUDA:
  - GPU:
    - NVIDIA A100-PCIE-40GB
  - available: True
  - version: 12.1
* Lightning:
  - lightning-utilities: 0.11.6
  - pytorch-lightning: 2.3.3
  - torch: 2.4.0
  - torchmetrics: 1.4.1
  - torchvision: 0.19.0
* Packages:
  - absl-py: 2.1.0
  - aiohappyeyeballs: 2.3.4
  - aiohttp: 3.10.1
  - aiosignal: 1.3.1
  - asttokens: 2.4.1
  - attrs: 24.1.0
  - autocommand: 2.2.2
  - backports.tarfile: 1.2.0
  - beautifulsoup4: 4.12.3
  - black: 24.8.0
  - certifi: 2024.7.4
  - charset-normalizer: 3.3.2
  - click: 8.1.7
  - comm: 0.2.2
  - datasets: 2.20.0
  - debugpy: 1.8.5
  - decorator: 5.1.1
  - dill: 0.3.8
  - exceptiongroup: 1.2.2
  - executing: 2.0.1
  - filelock: 3.15.4
  - frozenlist: 1.4.1
  - fsspec: 2024.5.0
  - gdown: 5.2.0
  - grpcio: 1.65.4
  - huggingface-hub: 0.24.5
  - idna: 3.7
  - importlib-metadata: 8.2.0
  - importlib-resources: 6.4.0
  - inflect: 7.3.1
  - ipykernel: 6.29.5
  - ipython: 8.26.0
  - isort: 5.13.2
  - jaraco.context: 5.3.0
  - jaraco.functools: 4.0.1
  - jaraco.text: 3.12.1
  - jedi: 0.19.1
  - jinja2: 3.1.4
  - jupyter-client: 8.6.2
  - jupyter-core: 5.7.2
  - lightning-utilities: 0.11.6
  - markdown: 3.6
  - markupsafe: 2.1.5
  - matplotlib-inline: 0.1.7
  - more-itertools: 10.3.0
  - mpmath: 1.3.0
  - multidict: 6.0.5
  - multiprocess: 0.70.16
  - mypy-extensions: 1.0.0
  - nest-asyncio: 1.6.0
  - networkx: 3.3
  - numpy: 2.0.1
  - nvidia-cublas-cu12:
  - nvidia-cuda-cupti-cu12: 12.1.105
  - nvidia-cuda-nvrtc-cu12: 12.1.105
  - nvidia-cuda-runtime-cu12: 12.1.105
  - nvidia-cudnn-cu12:
  - nvidia-cufft-cu12:
  - nvidia-curand-cu12:
  - nvidia-cusolver-cu12:
  - nvidia-cusparse-cu12:
  - nvidia-nccl-cu12: 2.20.5
  - nvidia-nvjitlink-cu12: 12.6.20
  - nvidia-nvtx-cu12: 12.1.105
  - ordered-set: 4.1.0
  - packaging: 24.1
  - pandas: 2.2.2
  - parso: 0.8.4
  - pathspec: 0.12.1
  - pexpect: 4.9.0
  - pickleshare: 0.7.5
  - pillow: 10.4.0
  - pip: 24.2
  - platformdirs: 4.2.2
  - prompt-toolkit: 3.0.47
  - protobuf: 4.25.4
  - psutil: 6.0.0
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.3
  - pyarrow: 17.0.0
  - pyarrow-hotfix: 0.6
  - pygments: 2.18.0
  - pynvml: 11.5.3
  - pysocks: 1.7.1
  - python-dateutil: 2.9.0
  - pytorch-lightning: 2.3.3
  - pytz: 2024.1
  - pyyaml: 6.0.1
  - pyzmq: 26.1.0
  - regex: 2024.7.24
  - requests: 2.32.3
  - safetensors: 0.4.4
  - setuptools: 72.1.0
  - six: 1.16.0
  - soupsieve: 2.5
  - stack-data: 0.6.2
  - sympy: 1.13.1
  - tensorboard: 2.17.0
  - tensorboard-data-server: 0.7.2
  - tiktoken: 0.7.0
  - tokenizers: 0.19.1
  - tomli: 2.0.1
  - torch: 2.4.0
  - torchmetrics: 1.4.1
  - torchvision: 0.19.0
  - tornado: 6.4.1
  - tqdm: 4.66.5
  - traitlets: 5.14.3
  - transformers: 4.44.0
  - triton: 3.0.0
  - typeguard: 4.3.0
  - typing-extensions: 4.12.2
  - tzdata: 2024.1
  - urllib3: 2.2.2
  - wcwidth: 0.2.13
  - werkzeug: 3.0.3
  - wheel: 0.44.0
  - xxhash: 3.4.1
  - yarl: 1.9.4
  - zipp: 3.19.2
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.12.4
  - release: 3.10.0-1160.71.1.el7.x86_64
  - version: #1 SMP Tue Jun 28 15:37:28 UTC 2022
```

richbai90 commented 3 weeks ago

Thanks for reporting the issue. Setting precision to '32-true' fixes the problem for me.
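
A small standard-library illustration of why float16 is the fragile setting (not Lightning code, and not a confirmed diagnosis of this issue): float16 cannot represent magnitudes above 65504, so a gradient multiplied by a loss scale of 65536 (the default `init_scale` of PyTorch's `GradScaler`) overflows once its magnitude exceeds roughly 1, while float32 holds the same value easily. With `precision="32-true"` no loss scaling is applied at all, so there is no scaler that could skip a step. The helper names `fits_in_fp16` and `fits_in_fp32` are hypothetical.

```python
import math
import struct

LOSS_SCALE = 65536.0  # assumption: GradScaler's default init_scale (2**16)


def fits_in_fp16(x: float) -> bool:
    """True if x packs into an IEEE 754 half-precision float without overflowing."""
    if not math.isfinite(x):
        return False
    try:
        struct.pack("<e", x)  # 'e' = IEEE 754 binary16
    except OverflowError:
        return False
    return True


def fits_in_fp32(x: float) -> bool:
    """True if x packs into an IEEE 754 single-precision float without overflowing."""
    if not math.isfinite(x):
        return False
    try:
        struct.pack("<f", x)  # 'f' = IEEE 754 binary32
    except OverflowError:
        return False
    return True


if __name__ == "__main__":
    for grad in (0.5, 2.0):
        scaled = grad * LOSS_SCALE
        print(f"grad={grad}: scaled={scaled}, "
              f"fp16 ok={fits_in_fp16(scaled)}, fp32 ok={fits_in_fp32(scaled)}")
```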

kopalja commented 3 weeks ago

Yes, but that is not really a solution. In addition, the problem might still be present and manifest itself at a different training step.

richbai90 commented 2 weeks ago

Agreed, it's not a fix, but it saved me from having to rewrite my implementation or tell my PI that we had to wait for a bug fix before we could finish our paper.