Lightning-AI / pytorch-lightning

TrainingEpochLoop._should_check_val_fx discrepancy between continued run <> restore from ckpt #14579

Open Anner-deJong opened 2 years ago

Anner-deJong commented 2 years ago

🐛 Bug

Found a discrepancy between a run continued after checkpointing and a run restored from that checkpoint.

Observation:

The train-batch / validation-loop ordering upon checkpoint restoration is not the same as in the original run continuing past the checkpoint save.

There is still the same number of train steps, but the validation loops are interleaved one step later, which can cause the restored run to end up with one less validation loop (see colab). As an illustration (made-up numbers): with val_check_interval=5, the original run might validate after batches 5, 10, 15, while the restored run validates after batches 6, 11 and never reaches a third validation loop.

Assumption / expectation:

Zero difference between a training run continued past a checkpoint and a run restored from said checkpoint.

Investigation so far:

I'm new to some of this Lightning code, but IIUC:

Key:

TrainingEpochLoop's self.batch_progress.increment_completed() is called after the "on_train_batch_end" hooks, and it is those hooks that kick off checkpoint saving.

  1. upon restoring, TrainingEpochLoop.batch_progress.current.reset_on_restart() resets ready back to completed
  2. yet the global_step, which refers to TrainingEpochLoop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed, has already had increment_completed() called on it (within TrainingEpochLoop.batch_loop.run), so upon restoring, ..optimizer.step.total.ready is reset to an already up-to-date optimizer.step.total.completed, out of sync with the above
  3. [simplification] in "val_check_interval mode", validation is triggered when TrainingEpochLoop.batch_progress.current.ready % val_check_interval == 0 (through TrainingEpochLoop.on_advance_end -> TrainingEpochLoop._should_check_val_fx)
  4. combining the above three: the same batch_progress.current ready/completed counters in the continued and restored runs end up aligned with different global_steps, and hence validation triggers at different global_steps (see the sketch below)
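
To make the drift concrete, here is a minimal sketch in plain Python (not Lightning code; Progress and the counter values are illustrative stand-ins for Lightning's progress-tracking dataclasses):

from dataclasses import dataclass

@dataclass
class Progress:
    ready: int = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # on restart, anything "ready" but not yet "completed" is rolled
        # back so the interrupted batch gets re-run
        self.ready = self.completed

# state when a checkpoint is saved from on_train_batch_end for batch 10,
# i.e. before batch_progress.increment_completed() has run:
batch_progress = Progress(ready=10, completed=9)
# ...while the optimizer step counter (the source of global_step) was
# already incremented inside batch_loop.run:
optimizer_step = Progress(ready=10, completed=10)

# restoring from that checkpoint:
batch_progress.reset_on_restart()  # ready=9, completed=9 -> batch 10 re-runs
optimizer_step.reset_on_restart()  # ready=10, completed=10 -> global_step already counts batch 10

# when batch 10 re-runs, batch_progress.ready reaches 10 again, but it now
# pairs with a global_step one higher than in the original run, so
# `ready % val_check_interval == 0` fires at a different global_step.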

Another observation:

The following if statement seems to allow for a zero-difference restart, except that, just like in 4. above, _should_check_val_fx wouldn't trigger on the checkpointing step where it did in the original run (although there it is called in on_advance_end). Not sure if the original intention of this snippet included the current scope.

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        ...
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return

PRs relevant to this line:

Potential impact:

Assuming this is not too worrisome for the more default Lightning use cases:

However, in theory it can influence all of the following:

To Reproduce

Customized Google Colab bug_report_model.ipynb showing the same observation with the BoringModel
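
For convenience, here is a self-contained sketch along the same lines (hypothetical: the intervals, checkpoint filename, and the inline BoringModel are made up for illustration; the linked colab is the authoritative repro):

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class RandomDataset(Dataset):
    # random features only; just enough to drive the loops
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # record at which global_step each validation loop runs
        print(f"validation at global_step={self.global_step}")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

train = DataLoader(RandomDataset(32, 64), batch_size=2)
val = DataLoader(RandomDataset(32, 64), batch_size=2)

# run 1: save mid-epoch checkpoints and note the validation global_steps
ckpt = pl.callbacks.ModelCheckpoint(dirpath="ckpts", every_n_train_steps=5, save_top_k=-1)
trainer = pl.Trainer(max_epochs=1, val_check_interval=4, num_sanity_val_steps=0,
                     callbacks=[ckpt], enable_progress_bar=False)
trainer.fit(BoringModel(), train, val)

# run 2: restore from a mid-epoch checkpoint; the printed global_steps differ
trainer = pl.Trainer(max_epochs=1, val_check_interval=4, num_sanity_val_steps=0,
                     enable_progress_bar=False)
trainer.fit(BoringModel(), train, val, ckpt_path="ckpts/epoch=0-step=5.ckpt")  # filename illustrative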

Expected behavior

Zero difference between a training run continued past a checkpoint and a run restored from said checkpoint.

Environment

Details
  • CUDA:
    • GPU:
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
      • NVIDIA RTX A4000
    • available: True
    • version: 11.0
  • Lightning:
    • efficientnet-pytorch: 0.7.1
    • pytorch-lightning: 1.6.4
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
  • Packages:
    • absl-py: 0.15.0
    • adal: 1.2.7
    • adlfs: 2021.10.0
    • aiohttp: 3.7.4
    • applicationinsights: 0.11.10
    • argcomplete: 1.12.3
    • async-timeout: 3.0.1
    • attrdict: 2.0.0
    • attrs: 21.1.0
    • av: 8.0.3
    • azure-cli-core: 2.38.0
    • azure-cli-telemetry: 1.0.6
    • azure-common: 1.1.27
    • azure-core: 1.20.0
    • azure-datalake-store: 0.0.52
    • azure-identity: 1.10.0
    • azure-keyvault-secrets: 4.2.0
    • azure-mgmt-core: 1.2.2
    • azure-storage-blob: 12.11.0
    • backcall: 0.2.0
    • backoff: 1.10.0
    • bcrypt: 3.2.0
    • cachetools: 4.2.2
    • certifi: 2020.12.5
    • cffi: 1.14.5
    • chardet: 3.0.4
    • charset-normalizer: 2.0.12
    • click: 7.1.2
    • confluent-kafka: 1.7.0
    • cryptography: 3.4.8
    • cycler: 0.10.0
    • datadog: 0.44.0
    • decorator: 5.0.7
    • deepdiff: 5.5.0
    • deltalake: 0.5.8
    • docker-pycreds: 0.4.0
    • efficientnet-pytorch: 0.7.1
    • einops: 0.4.1
    • filelock: 3.7.1
    • fonttools: 4.37.1
    • frozendict: 2.3.2
    • fsspec: 2022.1.0
    • gitdb: 4.0.7
    • gitpython: 3.1.14
    • google-auth: 1.30.0
    • google-auth-oauthlib: 0.4.4
    • grpcio: 1.37.1
    • htmlmin: 0.1.12
    • humanfriendly: 10.0
    • idna: 2.10
    • imagehash: 4.2.1
    • inplace-abn: 1.1.0a1110.post1103
    • ipdb: 0.13.9
    • ipython: 7.23.1
    • isodate: 0.6.0
    • jedi: 0.18.0
    • jinja2: 3.1.2
    • jmespath: 0.10.0
    • joblib: 1.0.1
    • kafka-python: 2.0.2
    • kiwisolver: 1.3.1
    • knack: 0.9.0
    • markdown: 3.3.4
    • markupsafe: 2.0.1
    • matplotlib: 3.5.3
    • matplotlib-inline: 0.1.2
    • methodtools: 0.1.2
    • missingno: 0.5.0
    • msal: 1.16.0
    • msal-extensions: 0.3.0
    • msrest: 0.6.21
    • msrestazure: 0.6.4
    • multidict: 5.1.0
    • multimethod: 1.6
    • networkx: 2.5.1
    • numpy: 1.22.4
    • oauthlib: 3.1.0
    • opencv-python: 4.4.0.44
    • ordered-set: 4.0.2
    • packaging: 21.3
    • pandas: 1.4.3
    • pandas-profiling: 3.1.0
    • paramiko: 2.7.2
    • parso: 0.8.2
    • pathtools: 0.1.2
    • pexpect: 4.8.0
    • phik: 0.12.0
    • pickleshare: 0.7.5
    • pillow: 9.2.0
    • pip: 22.0.3
    • pkginfo: 1.7.0
    • polyline: 1.4.0
    • portalocker: 1.7.1
    • prometheus-client: 0.8.0
    • promise: 2.3
    • prompt-toolkit: 2.0.10
    • protobuf: 3.15.8
    • psutil: 5.9.1
    • psycopg2: 2.8.3
    • ptyprocess: 0.7.0
    • py: 1.10.0
    • py3nvml: 0.2.7
    • pyarrow: 9.0.0
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.20
    • pydantic: 1.8.2
    • pydeprecate: 0.3.1
    • pygame: 2.1.2
    • pygments: 2.9.0
    • pyjwt: 1.7.1
    • pynacl: 1.4.0
    • pyntcloud: 0.1.6
    • pyopenssl: 20.0.1
    • pyparsing: 2.4.7
    • pyquaternion: 0.9.9
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • python-json-logger: 2.0.2
    • pytorch-lightning: 1.6.4
    • pytz: 2022.1
    • pywavelets: 1.1.1
    • pyyaml: 6.0
    • qrcode: 6.1
    • requests: 2.27.1
    • requests-oauthlib: 1.3.0
    • retry: 0.9.2
    • rsa: 4.7.2
    • runai: 0.3.0
    • scipy: 1.6.2
    • seaborn: 0.11.2
    • semver: 2.13.0
    • sentry-sdk: 1.9.4
    • setproctitle: 1.2.2
    • setuptools: 59.5.0
    • shapely: 1.8.0
    • shortuuid: 1.0.1
    • simplejpeg: 1.4.1
    • six: 1.16.0
    • slackclient: 2.9.4
    • smmap: 4.0.0
    • sqlalchemy: 1.3.24
    • tabulate: 0.8.9
    • tangled-up-in-unicode: 0.1.0
    • tensorboard: 2.6.0
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.0
    • timm: 0.4.5
    • toml: 0.10.2
    • torch: 1.11.0.post1103
    • torchmetrics: 0.7.0
    • torchvision: 0.12.0a1110.post1103
    • tqdm: 4.60.0
    • traitlets: 5.3.0
    • transforms3d: 0.3.1
    • typing-extensions: 4.1.1
    • urllib3: 1.26.11
    • visions: 0.7.4
    • wandb: 0.12.14
    • wcwidth: 0.2.5
    • werkzeug: 1.0.1
    • wheel: 0.36.2
    • wirerope: 0.3.1
    • wrapt: 1.14.1
    • xmltodict: 0.12.0
    • xxhash: 1.4.1
    • yarl: 1.6.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.12
    • version: #138~18.04.1-Ubuntu SMP Fri Jun 24 14:14:03 UTC 2022

Additional context

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @carmocca @justusschock

Anner-deJong commented 2 years ago

I'd be happy to help fix this if we agree it is a bug, but I'd like the opinion and take on a solution from folks more involved with PyTorch Lightning. I have some simple ideas, but none of them are great:

krshrimali commented 2 years ago

Hi @Anner-deJong - thank you for creating this issue and for providing all the context around it. @awaelchli or @carmocca can probably help you with this one.

lantiga commented 3 days ago

Getting back to this issue, as it came up recently in a different context. This is definitely a behavior that needs fixing.

My take would be not to add an extra hook, but to invoke increment_completed prior to on_train_batch_end. Which one comes first is very debatable from a definition standpoint, but having the increments already set up correctly when I'm in my on_train_batch_end hook is a fair expectation IMO.

The sequence would go from:

self.batch_progress.increment_ready()
on_train_batch_start
self.batch_progress.increment_started()
...
self.batch_progress.increment_processed()
...
on_train_batch_end
self.batch_progress.increment_completed()

to

self.batch_progress.increment_ready()
on_train_batch_start
self.batch_progress.increment_started()
...
self.batch_progress.increment_processed()
...
self.batch_progress.increment_completed()
on_train_batch_end

Not sure what to think about the started part, but that's for another time.
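
A rough sketch of what that reordering might look like inside TrainingEpochLoop.advance (hypothetical and heavily trimmed; _call_callback_hooks stands in for the full hook dispatch in the real loop):

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        batch = next(data_fetcher)
        self.batch_progress.increment_ready()

        self.trainer._call_callback_hooks("on_train_batch_start", batch, self.batch_idx)
        self.batch_progress.increment_started()

        batch_output = self.batch_loop.run(batch, self.batch_idx)
        self.batch_progress.increment_processed()

        # proposed change: mark the batch completed *before* on_train_batch_end,
        # so a checkpoint saved from that hook carries consistent progress counters
        self.batch_progress.increment_completed()
        self.trainer._call_callback_hooks("on_train_batch_end", batch_output, batch, self.batch_idx)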

lantiga commented 3 days ago

On further thought, I'll take a more conservative approach since the implications are pretty wide.