Open hrukalive opened 1 year ago
I think it is actually the moment when validation happens that drifts. The checkpoint saving is just a side effect.
The validation check tracks training batches instead of training steps. According to the documentation:

> An `int` value can only be higher than the number of training batches when `check_val_every_n_epoch=None`, which validates after every `N` training batches across epochs or during iteration-based training.
However, the number of training batches does not always equal the number of training steps (global steps). The training step is `(total_batch_idx // accumulate_grad_batches) + (accumulates_on_final_batch * epoch_trained)`, and the `accumulates_on_final_batch` term is where the drift happens.
I think it would make sense to validate after every N training steps instead of every N training batches. Other modules, such as the Logger and ModelCheckpoint, track training progress by global steps too.
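To make the drift concrete, here is a minimal sketch (plain Python, no Lightning) of the batch-based check under the numbers from this issue; the loop structure is an illustrative assumption, not Lightning's actual implementation:

```python
accumulate_grad_batches = 3
batches_per_epoch = 67                            # 67 % 3 == 1 -> one leftover batch
val_check_interval = accumulate_grad_batches * 5  # intended: validate every 5 optimizer steps

global_step = 0
for epoch in range(3):
    accumulated = 0
    for batch_idx in range(batches_per_epoch):
        accumulated += 1
        # the optimizer steps on every full accumulation window, and also
        # flushes the partial window left over at the end of the epoch
        if accumulated == accumulate_grad_batches or batch_idx == batches_per_epoch - 1:
            global_step += 1
            accumulated = 0
        # the validation check counts per-epoch batches, not optimizer steps
        if (batch_idx + 1) % val_check_interval == 0:
            print(f"epoch {epoch}, batch {batch_idx + 1}: validate at global_step {global_step}")
```

Epoch 0 validates at global steps 5, 10, 15, 20, but epoch 1 validates at 28, 33, 38, 43: each epoch contributes 23 optimizer steps (22 full windows plus the flushed leftover), so the checks slide off the intended 5-step grid by 3 steps every epoch.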
I propose we change the condition from the current batch-based version (which uses the batch index) to:

```python
elif self.trainer.val_check_batch != float("inf"):
    # if `check_val_every_n_epoch` is `None`, run a validation loop every n training steps;
    # else condition it on the batch_idx of the current epoch
    next_iteration = self.global_step if self.trainer.check_val_every_n_epoch is None else self.batch_idx + 1
    is_val_check_batch = next_iteration % self.trainer.val_check_batch == 0
```
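Since `global_step` advances only when the optimizer actually steps, this counts the interval in effective optimizer steps, so a leftover accumulation window at the epoch boundary no longer shifts when validation fires.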
Is there a plan to add step-based validation checks in Lightning?
Until Lightning adds official support for this, any recommendations on how we can override the default Lightning behavior and use number of steps instead of batches to trigger the validation?
Right now, as a workaround, I have to discard the last batch so that the number of batches is a multiple of the gradient-accumulation factor.
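A minimal sketch of that workaround, assuming a plain PyTorch `DataLoader` feeds the Trainer (`train_dataset` and the batch size are hypothetical placeholders):

```python
from torch.utils.data import DataLoader

# drop_last=True discards the leftover batch, so the per-epoch batch count is
# an exact multiple of accumulate_grad_batches and batch-based validation
# intervals line up with optimizer steps again
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
```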
Bug description
First of all, my task relies on step count instead of epochs, so I run validation checks by steps and save checkpoints after each one. However, once I turned gradient accumulation on, with a batch count that is not divisible by the accumulation factor, the actual step at which validation (and therefore checkpointing) happens started to drift.
In the example below, I override the `_save_checkpoint` function to monitor the actual file names, and they turn out to be drifting. My general setting is `val_check_interval=accumulation*5` to make it validate every 5 effective optimizer steps, with `accumulation=3` and `#batches=67`, so there is one batch left over.

How to reproduce the bug
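The original snippet was not captured here; below is a minimal sketch of the kind of setup described, assuming the Lightning 2.0 API. `ToyModel`, the random dataset, and `VerboseCheckpoint` are hypothetical stand-ins, and `_save_checkpoint` is a private `ModelCheckpoint` hook, overridden only to print where each checkpoint lands:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


class VerboseCheckpoint(ModelCheckpoint):
    def _save_checkpoint(self, trainer, filepath):
        # the file name encodes epoch/step, so printing it exposes the drift
        print(f"saving at global_step={trainer.global_step}: {filepath}")
        super()._save_checkpoint(trainer, filepath)


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).mean()

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# 67 batches of size 1 -> one leftover batch when accumulating by 3
dataset = TensorDataset(torch.randn(67, 4))
train_loader = DataLoader(dataset, batch_size=1)
val_loader = DataLoader(dataset, batch_size=1)

accumulation = 3
trainer = pl.Trainer(
    max_epochs=4,
    accumulate_grad_batches=accumulation,
    val_check_interval=accumulation * 5,  # intended: every 5 optimizer steps
    callbacks=[VerboseCheckpoint(save_top_k=-1, save_on_train_epoch_end=False)],
    logger=False,
    enable_progress_bar=False,
)
trainer.fit(ToyModel(), train_loader, val_loader)
```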
Error messages and logs
Environment
Current environment
``` * CUDA: - GPU: - NVIDIA RTX A5000 - NVIDIA RTX A5000 - NVIDIA RTX A5000 - NVIDIA RTX A5000 - available: True - version: 11.7 * Lightning: - lightning: 2.0.0 - lightning-cloud: 0.5.32 - lightning-lite: 1.8.6 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.0 - torch: 1.13.1 - torchaudio: 0.13.1 - torchcrepe: 0.0.17 - torchmetrics: 0.11.4 - torchvision: 0.14.1 * Packages: - absl-py: 1.3.0 - aiobotocore: 2.4.2 - aiohttp: 3.8.4 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - altgraph: 0.17.3 - anyio: 3.6.2 - appdirs: 1.4.4 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 22.2.0 - audioread: 3.0.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.0 - blessed: 1.20.0 - blinker: 1.4 - botocore: 1.27.59 - brotlipy: 0.7.0 - cachetools: 5.3.0 - certifi: 2022.12.7 - cffi: 1.15.1 - charset-normalizer: 2.0.4 - click: 8.1.3 - contourpy: 1.0.7 - croniter: 1.3.8 - cryptography: 39.0.1 - cycler: 0.11.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - distance: 0.1.3 - dnspython: 2.3.0 - einops: 0.6.0 - email-validator: 1.3.1 - et-xmlfile: 1.0.1 - fastapi: 0.88.0 - fire: 0.5.0 - flit-core: 3.8.0 - fonttools: 4.39.2 - frozenlist: 1.3.3 - fsspec: 2023.3.0 - future: 0.18.2 - g2p-en: 2.1.0 - g2pm: 0.1.2.5 - google-auth: 2.16.3 - google-auth-oauthlib: 0.4.6 - grpcio: 1.51.3 - h11: 0.14.0 - h5py: 3.7.0 - httpcore: 0.16.3 - httptools: 0.5.0 - httpx: 0.23.3 - idna: 3.4 - imageio: 2.23.0 - importlib-metadata: 6.1.0 - inflect: 6.0.2 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - kiwisolver: 1.4.4 - librosa: 0.9.1 - lightning: 2.0.0 - lightning-cloud: 0.5.32 - lightning-lite: 1.8.6 - lightning-utilities: 0.8.0 - llvmlite: 0.39.1 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.2 - matplotlib: 3.6.2 - mdurl: 0.1.2 - mkl-fft: 1.3.1 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - multidict: 6.0.4 - networkx: 3.0 - nltk: 3.8.1 - numba: 0.56.4 - numpy: 1.23.5 - oauthlib: 3.2.2 - ordered-set: 4.1.0 - orjson: 3.8.8 - packaging: 23.0 - pillow: 9.4.0 - pip: 23.0.1 - platformdirs: 3.1.1 - pooch: 1.7.0 - praat-parselmouth: 0.4.3 - protobuf: 3.13.0 - psutil: 5.9.4 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pycwt: 0.3.0a22 - pydantic: 1.10.7 - pygments: 2.14.0 - pyjwt: 2.6.0 - pyloudnorm: 0.1.0 - pyopenssl: 23.0.0 - pyparsing: 3.0.9 - pypinyin: 0.39.0 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-levenshtein: 0.12.2 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.0 - pytz: 2022.7.1 - pywavelets: 1.4.1 - pyyaml: 6.0 - readchar: 4.0.5 - regex: 2023.3.23 - requests: 2.28.1 - requests-oauthlib: 1.3.1 - resampy: 0.4.2 - resemblyzer: 0.1.1.dev0 - rfc3986: 1.5.0 - rich: 13.3.2 - rsa: 4.9 - s3fs: 2023.3.0 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.9.3 - setuptools: 65.6.3 - six: 1.16.0 - snakeviz: 2.1.1 - sniffio: 1.3.0 - soundfile: 0.12.1 - soupsieve: 2.4 - starlette: 0.22.0 - starsessions: 1.3.0 - tensorboard: 2.11.0 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.6 - termcolor: 2.2.0 - threadpoolctl: 3.1.0 - tifffile: 2023.3.21 - torch: 1.13.1 - torchaudio: 0.13.1 - torchcrepe: 0.0.17 - torchmetrics: 0.11.4 - torchvision: 0.14.1 - tornado: 6.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - typing: 3.7.4.3 - typing-extensions: 4.4.0 - ujson: 5.7.0 - urllib3: 1.26.14 - uvicorn: 0.21.1 - uvloop: 0.17.0 - watchfiles: 0.18.1 - wcwidth: 0.2.6 - webrtcvad: 2.0.10 - websocket-client: 1.5.1 - websockets: 10.4 - werkzeug: 2.2.3 - wheel: 0.38.4 - wrapt: 1.15.0 - yarl: 1.8.2 - zipp: 
3.15.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.16 - version: #153-Ubuntu SMP Thu Nov 24 15:56:58 UTC 2022
```

More info
Other than this phenomenon, I have two more questions: why is `val_check_interval` tied to the number of batches rather than `global_step`?

cc @carmocca @justusschock