Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Error when disabling an optimizer with native AMP turned on #20116

Open schopra8 opened 4 months ago

schopra8 commented 4 months ago

Bug description

I'm using two optimizers and trying to train with AMP (FP16). Steps with the first optimizer succeed, but the very first step with the second optimizer fails with the following error:

  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
    len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.

I can train this correctly in FP32 -- so it seems to be an issue with AMP.
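For what it's worth, the assertion comes from `torch.amp.GradScaler.step()`: during unscaling, the scaler records per-device inf checks only for parameters that actually have gradients, and it refuses to step an optimizer for which no checks were recorded. Here is a minimal plain-PyTorch sketch of that contract, outside Lightning (the two linear models and SGD optimizers are placeholders for illustration, not my actual code):

```python
import torch

model_a = torch.nn.Linear(8, 8).cuda()
model_b = torch.nn.Linear(8, 8).cuda()
opt_a = torch.optim.SGD(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.SGD(model_b.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    loss = model_a(x).sum()  # only model_a participates in this loss

scaler.scale(loss).backward()  # populates grads for model_a's params only
scaler.step(opt_a)             # OK: unscaling records inf checks for opt_a
scaler.step(opt_b)             # raises "No inf checks were recorded for this
                               # optimizer": none of opt_b's params have grads
scaler.update()
```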

What version are you seeing the problem on?

version 2.3.3

How to reproduce the bug

def training_step(self, batch: Dict, batch_idx: int):
    """
    We have 2 sets of optimizers.
    Every N batches (self.n_batches_per_optimizer), we make an optimizer update
    and switch which optimizer is updated.

    If self.n_batches_per_optimizer == 1, we make an update every batch and
    alternate optimizers every batch.

    If self.n_batches_per_optimizer > 1, we're doing gradient accumulation:
    we make an update every n_batches_per_optimizer batches and alternate
    optimizers every n_batches_per_optimizer batches.
    """
    opts = self.optimizers()
    current_cycle = (batch_idx // self.n_batches_per_optimizer) % len(opts)
    opt = opts[current_cycle]
    opt.zero_grad()

    if current_cycle == 0:
        compute_model_1_loss = True
    elif current_cycle == 1:
        compute_model_1_loss = False
    else:
        raise NotImplementedError(f"Unknown optimizer {current_cycle}")

    with opt.toggle_model():
        loss = self.inner_training_step(batch=batch, compute_model_1_loss=compute_model_1_loss)
        self.manual_backward(loss=loss)

        # Perform the optimization step every n_batches_per_optimizer batches
        if (batch_idx + 1) % self.n_batches_per_optimizer == 0:
            if not compute_model_1_loss:
                print("About to step the model-2 optimizer ...")
            opt.step()
            opt.zero_grad()
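For context, this is roughly how I'd wire the module and Trainer around the `training_step` above. This scaffold is a reconstruction (the sub-models, optimizer choices, and hyperparameters are placeholders); the parts relevant to the bug are manual optimization plus the AMP precision flag:

```python
import torch
import pytorch_lightning as pl

class TwoOptimizerModule(pl.LightningModule):
    def __init__(self, n_batches_per_optimizer: int = 1):
        super().__init__()
        # Manual optimization is required to use self.optimizers(),
        # manual_backward(), and opt.toggle_model() as in training_step above.
        self.automatic_optimization = False
        self.n_batches_per_optimizer = n_batches_per_optimizer
        self.model_1 = torch.nn.Linear(32, 32)  # placeholder sub-model
        self.model_2 = torch.nn.Linear(32, 32)  # placeholder sub-model

    def configure_optimizers(self):
        # One optimizer per sub-model; toggle_model() disables requires_grad
        # on the params that do not belong to the active optimizer.
        opt_1 = torch.optim.Adam(self.model_1.parameters(), lr=1e-4)
        opt_2 = torch.optim.Adam(self.model_2.parameters(), lr=1e-4)
        return opt_1, opt_2

# AMP (FP16) is enabled through the Trainer precision flag:
trainer = pl.Trainer(precision="16-mixed", accelerator="gpu", devices=1)
```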

Error messages and logs

Traceback (most recent call last):
  File "/home/sahil/train.py", line 82, in <module>
    main(config)
  File "/home/sahil/train.py", line 62, in main
    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 252, in advance
    batch_output = self.manual_optimization.run(kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 94, in run
    self.advance(kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/manual.py", line 114, in advance
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/home/sahil/model/model.py", line 169, in training_step
    opt.step()
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py", line 93, in optimizer_step
    step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
  File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 450, in step
    len(optimizer_state["found_inf_per_device"]) > 0
AssertionError: No inf checks were recorded for this optimizer.

Environment

Current environment

```
* CUDA:
    - GPU:
        - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.1
* Lightning:
    - lightning-utilities: 0.11.5
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
* Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - autocommand: 2.2.2
    - babel: 2.15.0
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - bitsandbytes: 0.43.1
    - bleach: 6.1.0
    - boto3: 1.34.144
    - botocore: 1.34.144
    - braceexpand: 0.1.7
    - certifi: 2024.7.4
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.5.82
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - opencv-python: 4.10.0.84
    - ordered-set: 4.1.0
    - overrides: 7.7.0
    - packaging: 24.1
    - pandocfilters: 1.5.1
    - parso: 0.8.4
    - pexpect: 4.9.0
    - pillow: 10.4.0
    - pip: 24.1
    - platformdirs: 4.2.2
    - pre-commit: 3.7.1
    - proglog: 0.1.10
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.47
    - protobuf: 5.27.2
    - psutil: 6.0.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pycparser: 2.22
    - pydantic: 2.8.2
    - pydantic-core: 2.20.1
    - pydantic-settings: 2.3.4
    - pygments: 2.18.0
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.3.3
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - referencing: 0.35.1
    - requests: 2.32.3
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rpds-py: 0.19.0
    - s3transfer: 0.10.2
    - send2trash: 1.8.3
    - sentry-sdk: 2.10.0
    - setproctitle: 1.3.3
    - setuptools: 71.0.2
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - stack-data: 0.6.3
    - sympy: 1.13.0
    - terminado: 0.18.1
    - tinycss2: 1.3.0
    - tomli: 2.0.1
    - torch: 2.3.1
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - triton: 2.3.1
    - typeguard: 4.3.0
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - uri-template: 1.3.0
    - urllib3: 2.2.2
    - virtualenv: 20.26.3
    - wandb: 0.17.4
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webdataset: 0.2.86
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - wheel: 0.43.0
    - yarl: 1.9.4
    - zipp: 3.19.2
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor:
    - python: 3.10.14
    - release: 5.10.0-31-cloud-amd64
    - version: #1 SMP Debian 5.10.221-1 (2024-07-14)
```

More info

No response

schopra8 commented 4 months ago

There were similar issues reported a few years back -- https://github.com/Lightning-AI/pytorch-lightning/issues/7792

And they seem to have been solved -- https://github.com/Lightning-AI/pytorch-lightning/pull/7975

So I'm not sure whether the bug was re-introduced in a later release or whether I'm missing something in my example code.
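If it helps with triage: the assertion fires when `GradScaler.step()` is handed an optimizer none of whose parameters had a `.grad` populated by the scaled backward. A hypothetical debugging aid (not part of the repro above) could be dropped in right before `opt.step()` to confirm which case we're in:

```python
# Hypothetical check (placeholder, not in my repro): count how many of this
# optimizer's params received gradients from manual_backward(). If this
# prints 0, GradScaler.step() will raise the assertion above.
n_with_grad = sum(
    p.grad is not None
    for group in opt.optimizer.param_groups  # LightningOptimizer wraps .optimizer
    for p in group["params"]
)
print(f"optimizer {current_cycle}: {n_with_grad} params with grads before step()")
```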