Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

PyTorch Lightning produces different loss when resuming from ckpt vs training without interruption #18098

Closed. dinhanhx closed this issue 1 year ago.

dinhanhx commented 1 year ago

Bug description

As the title says, the loss values differ when resuming from a checkpoint compared with training without interruption.

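First, run the code below without interruption (with ckpt_path set to None). Second, run the code again (also with ckpt_path set to None), wait until a certain step, and kill it. Finally, run the code again, this time resuming from a checkpoint saved by the second run.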

What version are you seeing the problem on?

v2.0

How to reproduce the bug

```python
from typing import Any, Optional

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers.optimization import get_cosine_schedule_with_warmup


class BoringTransformer(pl.LightningModule):
    def __init__(
        self, vocab_size: int, learning_rate: float = 5e-5, warmup_ratio: float = 0.1
    ) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        self.warmup_ratio = warmup_ratio
        self.transformer = Transformer(vocab_size=vocab_size)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        x, y = batch
        # The demo Transformer returns flattened log-probabilities, hence nll_loss
        z = self.transformer(x, y)
        loss = torch.nn.functional.nll_loss(z, y.view(-1))
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self) -> Any:
        opt = AdamW(self.parameters(), lr=self.learning_rate)
        # Number of optimizer steps (already accounts for accumulate_grad_batches)
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = get_cosine_schedule_with_warmup(
            opt,
            num_warmup_steps=int(total_steps * self.warmup_ratio),
            num_training_steps=total_steps,
        )
        # Step the scheduler after every optimizer step, not every epoch
        lrs = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [opt], [lrs]


pl.seed_everything(1312)
dataset = WikiText2()
dataloader = DataLoader(dataset, batch_size=16)
transformer = BoringTransformer(dataset.vocab_size)
trainer = pl.Trainer(
    enable_checkpointing=True,
    default_root_dir="boring_logs",
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    logger=[CSVLogger("boring_logs")],
    callbacks=[
        RichProgressBar(),
        # Save a checkpoint every 16 optimizer steps, plus last.ckpt
        ModelCheckpoint(every_n_train_steps=16, save_last=True),
    ],
    max_epochs=2,
    accumulate_grad_batches=8,
    log_every_n_steps=16,
    enable_model_summary=False,
)

# Runs 1 and 2 (uninterrupted / to be killed): set ckpt_path = None.
# Run 3 (resume): point it at a checkpoint written by run 2, e.g.:
ckpt_path: Optional[str] = (
    "boring_logs/lightning_logs/version_1/checkpoints/epoch=0-step=288.ckpt"
)
trainer.fit(transformer, dataloader, ckpt_path=ckpt_path)
```
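To see exactly where the runs diverge, the metrics.csv files written by CSVLogger can be compared step by step. This is a minimal sketch, assuming each file has step and train_loss columns (as logged by the script above) and that the resumed run logs the same step numbers as the uninterrupted one. The helper names read_losses and compare_runs are hypothetical, and the version_* directories depend on the order in which the runs were started.

```python
import csv
from typing import Dict


def read_losses(metrics_csv: str, key: str = "train_loss") -> Dict[int, float]:
    """Map step -> logged loss from a CSVLogger metrics.csv file.

    Rows where `key` is empty (e.g. rows that only hold other metrics) are skipped.
    """
    losses: Dict[int, float] = {}
    with open(metrics_csv, newline="") as f:
        for row in csv.DictReader(f):
            value = row.get(key, "")
            if value:
                losses[int(row["step"])] = float(value)
    return losses


def compare_runs(reference_csv: str, resumed_csv: str, atol: float = 1e-6) -> Dict[int, float]:
    """Return {step: |difference|} for steps logged by both runs whose losses differ by more than atol."""
    ref = read_losses(reference_csv)
    res = read_losses(resumed_csv)
    diffs: Dict[int, float] = {}
    for step in sorted(ref.keys() & res.keys()):
        delta = abs(ref[step] - res[step])
        if delta > atol:
            diffs[step] = delta
    return diffs
```

For example, compare_runs("boring_logs/lightning_logs/version_0/metrics.csv", "boring_logs/lightning_logs/version_2/metrics.csv") would list the post-resume steps at which the resumed run departs from the uninterrupted one. Since a resumed run writes to a fresh version directory, only the steps logged after resuming take part in the comparison.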

Error messages and logs

[Screenshot; not available in this copy]

Environment

Current environment
* CUDA:
  - GPU: Tesla P100-PCIE-16GB
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.0.5
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.9.0
  - pytorch-ignite: 0.4.12
  - pytorch-lightning: 2.0.4
  - torch: 2.0.0
  - torchaudio: 2.0.1
  - torchdata: 0.6.0
  - torchinfo: 1.8.0
  - torchmetrics: 1.0.0
  - torchtext: 0.15.1
  - torchvision: 0.15.1
* Packages: absl-py: 1.4.0 - accelerate: 0.20.3 - access: 1.1.9 - affine: 2.4.0 - aiobotocore: 2.5.2 - aiofiles: 22.1.0 - aiohttp: 3.8.4 - aiohttp-cors: 0.7.0 - aioitertools: 0.11.0 - aiorwlock: 1.3.0 - aiosignal: 1.3.1 - aiosqlite: 0.19.0 - albumentations: 1.3.1 - alembic: 1.11.1 - altair: 5.0.1 - annoy: 1.17.3 - ansiwrap: 0.8.4 - anyio: 3.7.0 - apache-beam: 2.46.0 - aplus: 0.11.0 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - array-record: 0.4.0 - arrow: 1.2.3 - arviz: 0.12.1 - astroid: 2.15.6 - astropy: 5.3.1 - asttokens: 2.2.1 - astunparse: 1.6.3 - async-timeout: 4.0.2 - atpublic: 4.0 - attrs: 23.1.0 - audioread: 3.0.0 - autopep8: 2.0.2 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - backports.functools-lru-cache: 1.6.4 - bayesian-optimization: 1.4.3 - bayespy: 0.5.26 - beatrix-jupyterlab: 2023.621.222118 - beautifulsoup4: 4.12.2 - bidict: 0.22.1 - biopython: 1.81 - blake3: 0.2.1 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.6.2 - blis: 0.7.9 - blosc2: 2.0.0 - bokeh: 3.1.1 - boltons: 23.0.0 - boruta: 0.3 - boto3: 1.26.100 - botocore: 1.29.161 - bq-helper: 0.4.1 - bqplot: 0.12.39 - branca: 0.6.0 - brewer2mpl: 1.4.1 - brotlipy: 0.7.0 - cachetools: 5.3.0 - cartopy: 0.21.1 - catalogue: 2.0.8 - catalyst: 22.4 - catboost: 1.2 - category-encoders: 2.6.1 - certifi: 2023.5.7 - cesium: 0.12.1 - cffi: 1.15.1 - cftime: 1.6.2 - charset-normalizer: 3.1.0 - chex: 0.1.81 - cleverhans: 4.0.0 - click: 8.1.4 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloud-tpu-client: 0.10 - cloud-tpu-profiler: 2.4.0 - cloudpickle: 2.2.1 - cmaes: 0.9.1 - cmdstanpy: 1.1.0 - cmudict: 1.0.13 - colorama: 0.4.6 - colorcet: 3.0.1 - colorful: 0.5.5 - colorlog: 6.7.0 - colorlover: 0.3.0 - comm: 0.1.3 - conda: 23.5.0 - conda-content-trust: 0+unknown - conda-package-handling: 2.0.2 - conda-package-streaming: 0.8.0 - confection: 0.1.0 - contextily: 1.3.0 - contourpy: 1.1.0 - convertdate: 2.4.0 - crcmod: 1.7 - croniter: 1.4.1 - cryptography: 41.0.1 - cubinlinker: 0.3.0 - cuda-python: 11.8.2 - cudf: 23.6.1 - cufflinks: 0.17.3 - cuml: 23.6.0 - cupy: 12.1.0 - cvxcanon: 0.1.2 - cycler: 0.11.0 - cymem: 2.0.7 - cysignals: 1.11.2 - cython: 0.29.35 - cytoolz: 0.12.0 - daal: 2023.1.1 - daal4py: 2023.1.1 - dacite: 1.8.1 - dask: 2023.7.0 - dask-cuda: 23.6.0 - dask-cudf: 23.6.1 - dataclasses-json: 0.5.9 - datasets: 2.1.0 - datashader: 0.15.1 - datashape: 0.5.2 - datatile: 1.0.3 - dateutils: 0.6.12 - db-dtypes: 1.1.1 - deap: 1.3.3 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.1 - defusedxml: 0.7.1 - delorean: 1.0.0 - deprecat: 2.1.1 - deprecated: 1.2.14 - deprecation: 2.1.0 - descartes: 1.1.0 - dill: 0.3.6 - dipy: 1.7.0 - distlib: 0.3.6 - distributed: 2023.3.2.1 - dm-tree: 0.1.8 - docker: 6.1.3 - docker-pycreds: 0.4.0 - docopt: 0.6.2 - docstring-parser: 0.15 - docstring-to-markdown: 0.12 - docutils: 0.20.1 - earthengine-api: 0.1.358 - easydict: 1.10 - easyocr: 1.7.0 - ecos: 2.0.12 - eli5: 0.13.0 - emoji: 2.6.0 - en-core-web-lg: 3.5.0 - en-core-web-sm: 3.5.0 - entrypoints: 0.4 - ephem: 4.1.4 - esda: 2.4.3 - essentia: 2.1b6.dev1034 - et-xmlfile: 1.1.0 - etils: 1.3.0 - exceptiongroup: 1.1.1 - executing: 1.2.0 - explainable-ai-sdk: 1.3.3 - fastai: 2.7.12 - fastapi: 0.98.0 - fastavro: 1.7.4 - fastcore: 1.5.29 - fastdownload: 0.0.7 - fasteners: 0.18 - fastjsonschema: 2.17.1 - fastprogress: 1.0.3 - fastrlock: 0.8 - fasttext: 0.9.2 - fbpca: 1.0 - feather-format: 0.4.1 - featuretools: 1.26.0 - filelock: 3.12.2 - fiona: 1.9.4.post1 - fitter: 1.5.2 - flake8: 6.0.0 - flashtext: 2.7 - flask: 2.3.2 - flatbuffers: 23.5.26 - flax: 0.7.0 - flit-core: 3.9.0 - folium: 0.14.0 - fonttools: 4.40.0 - fqdn: 1.5.1 - frozendict: 2.3.8 - frozenlist: 1.3.3 - fsspec: 2023.6.0 - funcy: 2.0 - fury: 0.9.0 - future: 0.18.3 - fuzzywuzzy: 0.18.0 - gast: 0.4.0 - gatspy: 0.3 - gcsfs: 2023.6.0 - gensim: 4.3.1 - geographiclib: 2.0 - geohash: 1.0 - geojson: 3.0.1 - geopandas: 0.13.2 - geoplot: 0.5.1 - geopy: 2.3.0 - geoviews: 1.10.0 - ggplot: 0.11.5 - giddy: 2.3.4 - gitdb: 4.0.10 - gitpython: 3.1.31 - google-api-core: 2.11.1 - google-api-python-client: 2.92.0 - google-apitools: 0.5.31 - google-auth: 2.20.0 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 1.0.0 - google-cloud-aiplatform: 0.6.0a1 - google-cloud-artifact-registry: 1.8.1 - google-cloud-automl: 1.0.1 - google-cloud-bigquery: 2.34.4 - google-cloud-bigtable: 1.7.3 - google-cloud-core: 2.3.2 - google-cloud-datastore: 2.16.1 - google-cloud-dlp: 3.12.1 - google-cloud-language: 2.10.1 - google-cloud-monitoring: 2.15.0 - google-cloud-pubsub: 2.17.1 - google-cloud-pubsublite: 1.8.2 - google-cloud-recommendations-ai: 0.7.1 - google-cloud-resource-manager: 1.10.1 - google-cloud-spanner: 3.36.0 - google-cloud-storage: 1.44.0 - google-cloud-translate: 3.11.2 - google-cloud-videointelligence: 2.11.3 - google-cloud-vision: 2.8.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.5.0 - googleapis-common-protos: 1.59.1 - gplearn: 0.4.2 - gpustat: 1.0.0 - gpxpy: 1.5.0 - graphviz: 0.20.1 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.51.3 - grpcio-status: 1.48.2 - gviz-api: 1.10.0 - gym: 0.26.2 - gym-notices: 0.0.8 - gymnasium: 0.26.3 - gymnasium-notices: 0.0.1 - h11: 0.14.0 - h2o: 3.42.0.1 - h5py: 3.9.0 - haversine: 2.8.0 - hdfs: 2.7.0 - hep-ml: 0.7.2 - hijri-converter: 2.3.1 - hmmlearn: 0.3.0 - holidays: 0.24 - holoviews: 1.16.2 - hpsklearn: 0.1.0 - html5lib: 1.1 - htmlmin: 0.1.12 - httplib2: 0.21.0 - httptools: 0.6.0 - huggingface-hub: 0.16.4 - humanize: 4.7.0 - hunspell: 0.5.5 - husl: 4.0.3 - hydra-slayer: 0.4.1 - hyperopt: 0.2.7 - hypertools: 0.8.0 - ibis-framework: 6.0.0 - idna: 3.4 - igraph: 0.10.5 - imagecodecs: 2023.7.4 - imagehash: 4.3.1 - imageio: 2.31.1 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - implicit: 0.5.2 - importlib-metadata: 6.8.0 - importlib-resources: 5.12.0 - inequality: 1.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - ipydatawidgets: 4.3.5 - ipykernel: 6.23.3 - ipyleaflet: 0.17.3 - ipympl: 0.7.0 - ipython: 8.14.0 - ipython-genutils: 0.2.0 - ipython-sql: 0.5.0 - ipyvolume: 0.6.3 - ipyvue: 1.9.2 - ipyvuetify: 1.8.10 - ipywebrtc: 0.6.0 - ipywidgets: 7.7.1 - isoduration: 20.11.0 - isort: 5.12.0 - isoweek: 1.3.3 - itsdangerous: 2.1.2 - janome: 0.5.0 - jaraco.classes: 3.2.3 - jax: 0.4.13 - jaxlib: 0.4.13+cuda11.cudnn86 - jedi: 0.18.2 - jeepney: 0.8.0 - jieba: 0.42.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.0 - json5: 0.9.14 - jsonpatch: 1.32 - jsonpointer: 2.0 - jsonschema: 4.17.3 - jupyter-client: 8.3.0 - jupyter-console: 6.6.3 - jupyter-core: 5.3.1 - jupyter-events: 0.6.3 - jupyter-http-over-ws: 0.0.8 - jupyter-lsp: 1.5.1 - jupyter-server: 2.7.0 - jupyter-server-fileid: 0.9.0 - jupyter-server-mathjax: 0.2.6 - jupyter-server-proxy: 4.0.0 - jupyter-server-terminals: 0.4.4 - jupyter-server-ydoc: 0.8.0 - jupyter-ydoc: 0.2.4 - jupyterlab: 3.6.5 - jupyterlab-git: 0.41.0 - jupyterlab-lsp: 4.2.0 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.23.0 - jupyterlab-widgets: 3.0.7 - jupytext: 1.14.6 - kaggle: 1.5.15 - kaggle-environments: 1.12.0 - keras: 2.12.0 - keras-cv: 0.5.0 - keras-tuner: 1.3.5 - keyring: 24.2.0 - keyrings.google-artifactregistry-auth: 1.1.2 - kfp: 2.0.1 - kfp-pipeline-spec: 0.2.2 - kfp-server-api: 2.0.0 - kiwisolver: 1.4.4 - kmapper: 2.0.1 - kmodes: 0.12.2 - korean-lunar-calendar: 0.3.1 - kornia: 0.6.12 - kt-legacy: 1.0.5 - kubernetes: 26.1.0 - langcodes: 3.3.0 - langid: 1.1.6 - lazy-loader: 0.2 - lazy-object-proxy: 1.9.0 - learntools: 0.3.4 - leven: 1.0.4 - levenshtein: 0.21.1 - libclang: 16.0.0 - libmambapy: 1.4.7 - libpysal: 4.7.0 - librosa: 0.10.0.post2 - lightgbm: 3.3.2 - lightning: 2.0.5 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - lime: 0.2.0.1 - line-profiler: 4.0.3 - linkify-it-py: 2.0.2 - llvmlite: 0.40.1 - lml: 0.1.0 - locket: 1.0.0 - lunarcalendar: 0.0.9 - lxml: 4.9.3 - lz4: 4.3.2 - mako: 1.2.4 - mamba: 1.4.7 - mapclassify: 2.5.0 - marisa-trie: 0.8.0 - markdown: 3.4.3 - markdown-it-py: 3.0.0 - markovify: 0.9.4 - markupsafe: 2.1.3 - marshmallow: 3.19.0 - marshmallow-enum: 1.5.1 - matplotlib: 3.7.2 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.9 - mccabe: 0.7.0 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - memory-profiler: 0.61.0 - mercantile: 1.2.1 - mgwr: 2.1.2 - missingno: 0.5.2 - mistune: 0.8.4 - mizani: 0.9.2 - ml-dtypes: 0.2.0 - mlcrate: 0.2.0 - mlens: 0.2.3 - mlxtend: 0.22.0 - mmh3: 4.0.0 - mne: 1.4.2 - mnist: 0.2.2 - mock: 5.0.2 - momepy: 0.6.0 - more-itertools: 9.1.0 - mpld3: 0.5.9 - mpmath: 1.3.0 - msgpack: 1.0.5 - msgpack-numpy: 0.4.8 - multidict: 6.0.4 - multimethod: 1.9.1 - multipledispatch: 1.0.0 - multiprocess: 0.70.14 - munkres: 1.1.4 - murmurhash: 1.0.9 - mypy-extensions: 1.0.0 - nb-conda: 2.2.1 - nb-conda-kernels: 2.3.1 - nbclassic: 1.0.0 - nbclient: 0.5.13 - nbconvert: 6.4.5 - nbdime: 3.2.0 - nbformat: 5.9.0 - nest-asyncio: 1.5.6 - netcdf4: 1.6.4 - networkx: 3.1 - nibabel: 5.1.0 - nilearn: 0.10.1 - ninja: 1.11.1 - nltk: 3.2.4 - nose: 1.3.7 - notebook: 6.5.4 - notebook-executor: 0.2 - notebook-shim: 0.2.3 - numba: 0.57.1 - numexpr: 2.8.4 - numpy: 1.25.0 - nvidia-ml-py: 11.495.46 - nvtx: 0.2.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - objsize: 0.6.1 - odfpy: 1.4.1 - olefile: 0.46 - onnx: 1.14.0 - opencensus: 0.11.2 - opencensus-context: 0.1.3 - opencv-contrib-python: 4.8.0.74 - opencv-python: 4.8.0.74 - opencv-python-headless: 4.8.0.74 - openpyxl: 3.1.2 - openslide-python: 1.2.0 - opentelemetry-api: 1.18.0 - opentelemetry-exporter-otlp: 1.18.0 - opentelemetry-exporter-otlp-proto-common: 1.18.0 - opentelemetry-exporter-otlp-proto-grpc: 1.18.0 - opentelemetry-exporter-otlp-proto-http: 1.18.0 - opentelemetry-proto: 1.18.0 - opentelemetry-sdk: 1.18.0 - opentelemetry-semantic-conventions: 0.39b0 - opt-einsum: 3.3.0 - optax: 0.1.5 - optuna: 3.2.0 - orbax-checkpoint: 0.2.7 - ordered-set: 4.1.0 - orderedmultidict: 1.0.1 - orjson: 3.9.1 - ortools: 9.4.1874 - osmnx: 1.1.1 - overrides: 7.3.1 - packaging: 21.3 - pandas: 2.0.2 - pandas-datareader: 0.10.0 - pandas-profiling: 3.6.6 - pandas-summary: 0.2.0 - pandasql: 0.7.3 - pandocfilters: 1.5.0 - panel: 1.2.0 - papermill: 2.4.0 - param: 1.13.0 - parso: 0.8.3 - parsy: 2.1 - partd: 1.4.0 - path: 16.7.1 - path.py: 12.5.0 - pathos: 0.3.0 - pathtools: 0.1.2 - pathy: 0.10.1 - patsy: 0.5.3 - pdf2image: 1.16.3 - pexpect: 4.8.0 - phik: 0.12.3 - pickleshare: 0.7.5 - pillow: 9.5.0 - pip: 23.1.2 - pkgutil-resolve-name: 1.3.10 - platformdirs: 3.8.1 - plotly: 5.15.0 - plotly-express: 0.4.1 - plotnine: 0.10.1 - pluggy: 1.0.0 - pointpats: 2.3.0 - polars: 0.18.6 - polyglot: 16.7.4 - pooch: 1.6.0 - pox: 0.3.2 - ppca: 0.0.4 - ppft: 1.7.6.6 - preprocessing: 0.1.13 - preshed: 3.0.8 - prettytable: 3.8.0 - progressbar2: 4.2.0 - prometheus-client: 0.17.0 - promise: 2.3 - prompt-toolkit: 3.0.38 - pronouncing: 0.2.0 - prophet: 1.1.1 - proto-plus: 1.22.3 - protobuf: 4.21.12 - psutil: 5.9.5 - ptxcompiler: 0.8.1 - ptyprocess: 0.7.0 - pudb: 2022.1.3 - pulp: 2.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 9.0.0 - py-lz4framed: 0.14.0 - py-spy: 0.3.14 - py4j: 0.10.9.7 - pyaml: 23.7.0 - pyarabic: 0.6.15 - pyarrow: 11.0.0 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pyastronomy: 0.19.0 - pybind11: 2.10.4 - pyclipper: 1.3.0.post4 - pycodestyle: 2.10.0 - pycolmap: 0.4.0 - pycosat: 0.6.4 - pycparser: 2.21 - pycryptodome: 3.18.0 - pyct: 0.5.0 - pycuda: 2022.2.2 - pydantic: 1.10.10 - pydegensac: 0.1.2 - pydicom: 2.4.1 - pydocstyle: 6.3.0 - pydot: 1.4.2 - pydub: 0.25.1 - pyemd: 1.0.0 - pyerfa: 2.0.0.3 - pyexcel-io: 0.6.6 - pyexcel-ods: 0.6.0 - pyfasttext: 0.4.6 - pyflakes: 3.0.1 - pygltflib: 1.15.6 - pygments: 2.15.1 - pyjwt: 2.7.0 - pykalman: 0.9.5 - pyldavis: 3.2.2 - pylibraft: 23.6.2 - pylint: 2.17.4 - pymc3: 3.11.5 - pymeeus: 0.5.12 - pymongo: 3.13.0 - pympler: 1.0.1 - pynndescent: 0.5.10 - pynvml: 11.4.1 - pynvrtc: 9.2 - pyocr: 0.8.3 - pyopenssl: 23.2.0 - pyparsing: 3.1.0 - pypdf: 3.12.0 - pyproj: 3.6.0 - pyrsistent: 0.19.3 - pysal: 23.1 - pyshp: 2.3.1 - pysocks: 1.7.1 - pytesseract: 0.3.10 - pytest: 7.4.0 - python-bidi: 0.4.2 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-igraph: 0.10.5 - python-json-logger: 2.0.7 - python-levenshtein: 0.21.1 - python-louvain: 0.16 - python-lsp-jsonrpc: 1.0.0 - python-lsp-server: 1.7.4 - python-multipart: 0.0.6 - python-slugify: 8.0.1 - python-utils: 3.7.0 - pythreejs: 2.4.2 - pytoolconfig: 1.2.5 - pytools: 2023.1 - pytorch-ignite: 0.4.12 - pytorch-lightning: 2.0.4 - pytz: 2023.3 - pyu2f: 0.1.5 - pyupset: 0.1.1.post7 - pyviz-comms: 2.3.2 - pywavelets: 1.4.1 - pyyaml: 6.0 - pyzmq: 25.1.0 - qgrid: 1.3.1 - qtconsole: 5.4.3 - qtpy: 2.3.1 - quantecon: 0.7.1 - quantities: 0.14.1 - qudida: 0.0.4 - raft-dask: 23.6.2 - randomgen: 1.23.1 - rapidfuzz: 3.1.1 - rasterio: 1.3.8 - rasterstats: 0.19.0 - ray: 2.5.1 - ray-cpp: 2.5.1 - readchar: 4.0.5 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requests-toolbelt: 0.10.1 - responses: 0.18.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rgf-python: 3.12.0 - rich: 13.4.2 - rmm: 23.6.0 - rope: 1.9.0 - rsa: 4.9 - rtree: 1.0.1 - ruamel-yaml-conda: 0.15.100 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - s2sphere: 0.2.5 - s3fs: 2023.6.0 - s3transfer: 0.6.1 - safetensors: 0.3.1 - scattertext: 0.1.19 - scikit-image: 0.21.0 - scikit-learn: 1.2.2 - scikit-learn-intelex: 2023.1.1 - scikit-multilearn: 0.2.0 - scikit-optimize: 0.9.0 - scikit-plot: 0.3.7 - scikit-surprise: 1.1.3 - scipy: 1.11.1 - seaborn: 0.12.2 - secretstorage: 3.3.3 - segment-anything: 1.0 - segregation: 2.4.2 - semver: 3.0.1 - send2trash: 1.8.2 - sentencepiece: 0.1.99 - sentry-sdk: 1.27.1 - setproctitle: 1.3.2 - setuptools: 59.8.0 - setuptools-git: 1.2 - setuptools-scm: 7.1.0 - shap: 0.42.0 - shapely: 1.8.5.post1 - shellingham: 1.5.1 - simpervisor: 1.0.0 - simpleitk: 2.2.1 - simplejson: 3.19.1 - six: 1.16.0 - sklearn-pandas: 2.2.0 - slicer: 0.0.7 - smart-open: 6.3.0 - smhasher: 0.150.1 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - snuggs: 1.4.7 - sortedcontainers: 2.4.0 - soundfile: 0.12.1 - soupsieve: 2.4.1 - soxr: 0.3.5 - spacy: 3.5.4 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - spaghetti: 1.7.4 - spectral: 0.23.1 - spglm: 1.0.8 - sphinx-rtd-theme: 0.2.4 - spint: 1.0.7 - splot: 1.1.5.post1 - spopt: 0.5.0 - spreg: 1.3.2 - spvcm: 0.3.0 - sqlalchemy: 2.0.17 - sqlglot: 17.3.0 - sqlparse: 0.4.4 - squarify: 0.4.3 - srsly: 2.4.6 - stack-data: 0.6.2 - starlette: 0.27.0 - starsessions: 1.3.0 - statsmodels: 0.14.0 - stemming: 1.0.1 - stop-words: 2018.7.23 - stopit: 1.1.2 - stumpy: 1.11.1 - sympy: 1.12 - tables: 3.8.0 - tabulate: 0.9.0 - tangled-up-in-unicode: 0.2.0 - tbb: 2021.9.0 - tblib: 1.7.0 - tenacity: 8.2.2 - tensorboard: 2.12.3 - tensorboard-data-server: 0.7.1 - tensorboard-plugin-profile: 2.13.0 - tensorboardx: 2.6 - tensorflow: 2.12.0 - tensorflow-addons: 0.20.0 - tensorflow-cloud: 0.1.16 - tensorflow-datasets: 4.9.2 - tensorflow-decision-forests: 1.4.0 - tensorflow-estimator: 2.12.0 - tensorflow-gcs-config: 2.12.0 - tensorflow-hub: 0.12.0 - tensorflow-io: 0.31.0 - tensorflow-io-gcs-filesystem: 0.31.0 - tensorflow-metadata: 0.14.0 - tensorflow-probability: 0.20.1 - tensorflow-serving-api: 2.12.1 - tensorflow-text: 2.12.1 - tensorflow-transform: 0.14.0 - tensorflowjs: 3.15.0 - tensorpack: 0.11 - tensorstore: 0.1.40 - termcolor: 2.3.0 - terminado: 0.17.1 - testpath: 0.6.0 - text-unidecode: 1.3 - textblob: 0.17.1 - texttable: 1.6.7 - textwrap3: 0.9.2 - theano: 1.0.5 - theano-pymc: 1.1.2 - thinc: 8.1.10 - threadpoolctl: 3.1.0 - tifffile: 2023.4.12 - timm: 0.9.2 - tinycss2: 1.2.1 - tobler: 0.10 - tokenizers: 0.13.3 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - toolz: 0.12.0 - torch: 2.0.0 - torchaudio: 2.0.1 - torchdata: 0.6.0 - torchinfo: 1.8.0 - torchmetrics: 1.0.0 - torchtext: 0.15.1 - torchvision: 0.15.1 - tornado: 6.3.2 - tpot: 0.12.0 - tqdm: 4.65.0 - traceml: 1.0.8 - traitlets: 5.9.0 - traittypes: 0.2.1 - transformers: 4.30.2 - treelite: 3.2.0 - treelite-runtime: 3.2.0 - trueskill: 0.4.5 - tsfresh: 0.20.1 - typeguard: 2.13.3 - typer: 0.9.0 - typing-extensions: 4.6.3 - typing-inspect: 0.9.0 - typing-utils: 0.1.0 - tzdata: 2023.3 - tzlocal: 5.0.1 - uc-micro-py: 1.0.2 - ucx-py: 0.32.0 - ujson: 5.8.0 - umap-learn: 0.5.3 - unicodedata2: 15.0.0 - unidecode: 1.3.6 - update-checker: 0.18.0 - uri-template: 1.3.0 - uritemplate: 3.0.1 - urllib3: 1.26.16 - urwid: 2.1.2 - urwid-readline: 0.13 - uvicorn: 0.22.0 - uvloop: 0.17.0 - vaex: 4.16.0 - vaex-astro: 0.9.3 - vaex-core: 4.16.1 - vaex-hdf5: 0.14.1 - vaex-jupyter: 0.8.1 - vaex-ml: 0.18.1 - vaex-server: 0.8.1 - vaex-viz: 0.5.4 - vecstack: 0.4.0 - virtualenv: 20.21.0 - visions: 0.7.5 - vowpalwabbit: 9.8.0 - vtk: 9.2.6 - wand: 0.6.11 - wandb: 0.15.5 - wasabi: 1.1.2 - watchfiles: 0.19.0 - wavio: 0.0.7 - wcwidth: 0.2.6 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 11.0.3 - werkzeug: 2.3.6 - wfdb: 4.1.2 - whatthepatch: 1.0.5 - wheel: 0.40.0 - widgetsnbextension: 3.6.4 - witwidget: 1.8.1 - woodwork: 0.24.0 - wordbatch: 1.4.9 - wordcloud: 1.9.2 - wordsegment: 1.3.1 - wrapt: 1.14.1 - wurlitzer: 3.0.3 - xarray: 2023.6.0 - xarray-einstats: 0.5.1 - xgboost: 1.7.6 - xvfbwrapper: 0.2.9 - xxhash: 3.2.0 - xyzservices: 2023.5.0 - y-py: 0.5.9 - yapf: 0.40.1 - yarl: 1.9.2 - ydata-profiling: 4.3.1 - yellowbrick: 1.5 - ypy-websocket: 0.8.2 - zict: 3.0.0 - zipp: 3.15.0 - zstandard: 0.19.0
* System:
  - OS: Linux
  - architecture: 64bit
  - processor: x86_64
  - python: 3.10.12
  - release: 5.15.109+
  - version: #1 SMP Thu Jul 13 10:52:19 UTC 2023

More info

I provide a Kaggle notebook to reproduce the issue quickly:

https://www.kaggle.com/inhanhv/resume-not-same-with-from-scratch-training

I also googled this behaviour and found a related forum post:

https://lightning.ai/forums/t/resuming-training-gives-different-model-result-weights/2677

dinhanhx commented 1 year ago

At a larger scale (bigger model, bigger dataset, bigger batch size), the differences become noticeable.

In the following pictures, the purple line is the resumed continuation of the blue line, and the green line is the training without interruption. As we can see, the blue line and the green line are the same until the interruption. After the interruption, the purple line is very different from the green line.

[Two screenshots of the loss curves described above; not available in this copy]

awaelchli commented 1 year ago

@dinhanhx This is expected: after resuming, the random state of the program (e.g. in torch) is different from what it was when training stopped. What matters is that the training loss converges to the same result in the end, this seems to be the case in your experiments (the curves are not identical, but their average value is the same). Please let me know if I should explain it in more detail.
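The effect described above can be reproduced in miniature. The repro script calls pl.seed_everything(1312) at startup, so a resumed process starts its random streams from the seed again rather than from the state they had at the interruption. This sketch uses Python's stdlib random.Random as a stand-in for torch's generator, and one draw per step as a stand-in for the randomness a step consumes (dropout masks and similar); the step counts are hypothetical.

```python
import random
from typing import List

SEED = 1312
TOTAL_STEPS = 6
INTERRUPT_AT = 3  # hypothetical step at which the run is killed


def noise_for_steps(rng: random.Random, n: int) -> List[float]:
    """Draw one random number per training step (stand-in for dropout masks etc.)."""
    return [rng.random() for _ in range(n)]


# Uninterrupted run: one generator, seeded once, used for every step.
uninterrupted = noise_for_steps(random.Random(SEED), TOTAL_STEPS)

# Resumed run: the new process seeds again with the same SEED, so its stream
# starts over from the beginning instead of continuing from step INTERRUPT_AT.
resumed_tail = noise_for_steps(random.Random(SEED), TOTAL_STEPS - INTERRUPT_AT)

print(uninterrupted[INTERRUPT_AT:] == resumed_tail)  # False: the streams are misaligned
print(uninterrupted[:TOTAL_STEPS - INTERRUPT_AT] == resumed_tail)  # True: early draws are replayed
```

The model weights and optimizer state come back from the checkpoint, but the randomness fed into the remaining steps does not match, so the loss curves diverge while following the same overall trend.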

Restoring the random state to exactly the way it was when stopped is highly non-trivial. We investigated this in the past but found it too complex, while at the same time it is rarely needed in practice.
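For a single generator, restoring the state is easy: store it in the checkpoint and set it after seeding. This is a minimal sketch with Python's stdlib random.Random; the train, save_checkpoint and load_checkpoint helpers are hypothetical. It also hints at why the general case is hard: a real run would need the same treatment for Python's random, NumPy (numpy.random.get_state/set_state), the torch CPU generator (torch.get_rng_state/torch.set_rng_state), every CUDA device (torch.cuda.get_rng_state_all/torch.cuda.set_rng_state_all) and the seeds of DataLoader workers, and nondeterministic GPU kernels can still make results diverge. Wiring such calls into LightningModule.on_save_checkpoint/on_load_checkpoint is a plausible place to start, but that is an untested assumption, not something Lightning provides.

```python
import random
from typing import Any, Dict, List

SEED = 1312


def train(rng: random.Random, steps: int) -> List[float]:
    """Hypothetical training loop: each step consumes randomness (like dropout)."""
    return [rng.random() for _ in range(steps)]


def save_checkpoint(rng: random.Random) -> Dict[str, Any]:
    # Store the generator state next to the (omitted) model/optimizer state.
    return {"rng_state": rng.getstate()}


def load_checkpoint(rng: random.Random, ckpt: Dict[str, Any]) -> None:
    # Overwrite whatever state seeding produced with the saved one.
    rng.setstate(ckpt["rng_state"])


# Reference: 6 uninterrupted steps.
reference = train(random.Random(SEED), 6)

# Run 3 steps, checkpoint, then "restart" in a fresh process that seeds again.
first_rng = random.Random(SEED)
first_half = train(first_rng, 3)
ckpt = save_checkpoint(first_rng)

resumed_rng = random.Random(SEED)  # the new process re-seeds at startup...
load_checkpoint(resumed_rng, ckpt)  # ...but then restores the saved state
second_half = train(resumed_rng, 3)

print(first_half + second_half == reference)  # True
```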

aweinmann commented 1 year ago

#17543 seems related.

dinhanhx commented 1 year ago

> What matters is that the training loss converges to the same result in the end, this seems to be the case in your experiments (the curves are not identical, but their average value is the same).

I have noticed a similar pattern in my other training experiments besides this one. I guess the difference would not affect the final accuracy. Thanks @awaelchli for explaining.

awaelchli commented 1 year ago

> I have noticed the similar pattern with my other training experiments other than this issue

Just to clarify, I'm not sure what you are saying. Is it that in general you agree with me and that training experiments from the past have shown this behavior, but this particular experiment you showed in this issue does not follow that? However, the loss curves you posted here follow the same trend, and the only difference is the variation in the loss bumps.
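Whether two curves "follow the same trend" and differ only in the bumps can be checked numerically by comparing averages over windows of steps instead of individual values. This is a minimal sketch; window_means and same_trend are hypothetical helpers, the inputs would be per-step loss lists covering the same steps (for example read from metrics.csv), and the window size and tolerance are arbitrary choices.

```python
from statistics import fmean
from typing import List, Sequence


def window_means(losses: Sequence[float], window: int) -> List[float]:
    """Mean of each consecutive, non-overlapping window; a trailing partial window is dropped."""
    return [fmean(losses[i:i + window]) for i in range(0, len(losses) - window + 1, window)]


def same_trend(a: Sequence[float], b: Sequence[float], window: int, rel_tol: float) -> bool:
    """True if the windowed means of two equally long loss curves agree within rel_tol."""
    if len(a) != len(b):
        raise ValueError("curves must cover the same steps")
    return all(
        abs(x - y) <= rel_tol * max(abs(x), abs(y))
        for x, y in zip(window_means(a, window), window_means(b, window))
    )
```

With window=1 this reduces to a step-by-step comparison, which the resumed and uninterrupted runs fail; larger windows average out the bumps and test only the trend.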

awaelchli commented 1 year ago

Hey again @dinhanhx, I just want to make sure I have all the information here. So is the explanation consistent with all your experiments and the concerns resolved? Or is there still something unclear? Please let me know so we can either close the ticket or look into it if something is still not working.

dinhanhx commented 1 year ago

> So is the explanation consistent with all your experiments and the concerns resolved?

Yes. Everything is clear, I suppose.