Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

on_epoch=True reduction of low precision types (bf16, etc) results in very inaccurate metrics #18620

Closed MF-FOOM closed 1 year ago

MF-FOOM commented 1 year ago

Bug description

When training with a low-precision type (fp16, bf16, etc.), logging loss or metric values via self.log(..., on_epoch=True) will yield very inaccurate reductions (whether mean, sum, etc.).

This is because, instead of using the torch functions for these operations (torch.mean, torch.sum, etc.), Lightning currently does the reduction manually, simply adding up new values as they're logged (and then dividing at the end in the case of mean).

The issue with this is that, with low-precision types, float non-associativity becomes a really big deal: the accumulated value can get stuck if the logged values aren't large enough to push the accumulator to the next representable number (e.g. 256 + 1 == 256 in bfloat16, because the next representable value above 256 is 258).

torch.mean, torch.sum, etc. all help mitigate this under the hood (e.g. such that summing [256, 1, 1] gives 258 instead of getting stuck at 256), but since Lightning does not use these functions, precision suffers greatly.
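Here is a minimal sketch of the rounding behavior described above, in plain PyTorch and independent of Lightning:

```python
import torch

# bfloat16 keeps only 8 significand bits, so above 256 the spacing
# between representable values is 2: adding 1.0 to a running sum that
# has reached 256 rounds back down, and the accumulator is stuck.
acc = torch.tensor(256.0, dtype=torch.bfloat16)
acc = acc + torch.tensor(1.0, dtype=torch.bfloat16)
print(acc)  # tensor(256., dtype=torch.bfloat16)

# Reducing the whole tensor at once avoids the stuck accumulator in
# this case, because PyTorch accumulates low-precision sums in a wider
# type internally.
vals = torch.tensor([256.0, 1.0, 1.0], dtype=torch.bfloat16)
print(vals.sum())  # tensor(258., dtype=torch.bfloat16)
```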

However, even if we were to refactor the accumulation logic to use these torch operations, I still worry that reducing in these small types is simply not precise enough, and that it's an easy trap for users to fall into without noticing. I personally have been casting my loss values to float32 to remedy this.

I can see a couple possible solutions:

  • Simply add a warning about logging low-precision types when on_epoch=True is enabled
  • Use torch reduction operations to mitigate associativity issues
  • Add a warning & use the torch reduction operations
  • Auto-cast logged values to float32 (or 64?) under the hood

What version are you seeing the problem on?

master

How to reproduce the bug

I've demonstrated how bad these precision issues can be here:

https://colab.research.google.com/drive/1u6NpRSzHBmNsQ1n18rIdtRBQaJOMp1mM?usp=sharing

Observe that I'm logging a fixed constant value of 1 on each validation step, yet the reduced (mean) value comes out to 0.51171875.
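The core effect can also be reproduced without Lightning by emulating its running-sum reduction in plain PyTorch. A minimal sketch (the step count of 500 is an assumption chosen so the numbers match; the notebook's exact setup may differ):

```python
import torch

# Emulate Lightning's manual mean reduction: a bfloat16 running sum of
# 500 logged values of 1.0, divided by the number of steps at the end.
acc = torch.tensor(0.0, dtype=torch.bfloat16)
for _ in range(500):
    acc = acc + torch.tensor(1.0, dtype=torch.bfloat16)

print(acc)        # tensor(256., dtype=torch.bfloat16) -- stuck at 256
print(acc / 500)  # tensor(0.5117, dtype=torch.bfloat16) instead of 1.0
```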

Error messages and logs

N/A

Environment

Current environment * CUDA: - GPU: - Tesla T4 - available: True - version: 11.8 * Lightning: - torch: 2.0.1+cu118 - torchaudio: 2.0.2+cu118 - torchdata: 0.6.1 - torchsummary: 1.5.1 - torchtext: 0.15.2 - torchvision: 0.15.2+cu118 * Packages: - absl-py: 1.4.0 - aiohttp: 3.8.5 - aiosignal: 1.3.1 - alabaster: 0.7.13 - albumentations: 1.3.1 - altair: 4.2.2 - anyio: 3.7.1 - appdirs: 1.4.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - array-record: 0.4.1 - arviz: 0.15.1 - astropy: 5.3.3 - astunparse: 1.6.3 - async-timeout: 4.0.3 - attrs: 23.1.0 - audioread: 3.0.0 - autograd: 1.6.2 - babel: 2.12.1 - backcall: 0.2.0 - beautifulsoup4: 4.11.2 - bleach: 6.0.0 - blinker: 1.4 - blis: 0.7.10 - blosc2: 2.0.0 - bokeh: 3.2.2 - bqplot: 0.12.40 - branca: 0.6.0 - build: 1.0.3 - cachecontrol: 0.13.1 - cachetools: 5.3.1 - catalogue: 2.0.9 - certifi: 2023.7.22 - cffi: 1.15.1 - chardet: 5.2.0 - charset-normalizer: 3.2.0 - chex: 0.1.7 - click: 8.1.7 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloudpickle: 2.2.1 - cmake: 3.27.4.1 - cmdstanpy: 1.1.0 - colorcet: 3.0.1 - colorlover: 0.3.0 - colour: 0.1.5 - community: 1.0.0b1 - confection: 0.1.2 - cons: 0.4.6 - contextlib2: 21.6.0 - contourpy: 1.1.0 - convertdate: 2.4.0 - cryptography: 41.0.3 - cufflinks: 0.17.3 - cupy-cuda11x: 11.0.0 - cvxopt: 1.3.2 - cvxpy: 1.3.2 - cycler: 0.11.0 - cymem: 2.0.7 - cython: 3.0.2 - dask: 2023.8.1 - datascience: 0.17.6 - db-dtypes: 1.1.1 - dbus-python: 1.2.18 - debugpy: 1.6.6 - decorator: 4.4.2 - defusedxml: 0.7.1 - distributed: 2023.8.1 - distro: 1.7.0 - dlib: 19.24.2 - dm-tree: 0.1.8 - docutils: 0.18.1 - dopamine-rl: 4.0.6 - duckdb: 0.8.1 - earthengine-api: 0.1.368 - easydict: 1.10 - ecos: 2.0.12 - editdistance: 0.6.2 - eerepr: 0.0.4 - en-core-web-sm: 3.6.0 - entrypoints: 0.4 - ephem: 4.1.4 - et-xmlfile: 1.1.0 - etils: 1.4.1 - etuples: 0.3.9 - exceptiongroup: 1.1.3 - fastai: 2.7.12 - fastcore: 1.5.29 - fastdownload: 0.0.7 - fastjsonschema: 2.18.0 - fastprogress: 1.0.3 - fastrlock: 0.8.2 - filelock: 3.12.2 - fiona: 1.9.4.post1 - firebase-admin: 5.3.0 - flask: 2.2.5 - flatbuffers: 23.5.26 - flax: 0.7.2 - folium: 0.14.0 - fonttools: 4.42.1 - frozendict: 2.3.8 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - future: 0.18.3 - gast: 0.4.0 - gcsfs: 2023.6.0 - gdal: 3.4.3 - gdown: 4.6.6 - geemap: 0.26.0 - gensim: 4.3.2 - geocoder: 1.38.1 - geographiclib: 2.0 - geopandas: 0.13.2 - geopy: 2.3.0 - gin-config: 0.5.0 - glob2: 0.7 - google: 2.0.3 - google-api-core: 2.11.1 - google-api-python-client: 2.84.0 - google-auth: 2.17.3 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 1.0.0 - google-cloud-bigquery: 3.10.0 - google-cloud-bigquery-connection: 1.12.1 - google-cloud-bigquery-storage: 2.22.0 - google-cloud-core: 2.3.3 - google-cloud-datastore: 2.15.2 - google-cloud-firestore: 2.11.1 - google-cloud-functions: 1.13.2 - google-cloud-language: 2.9.1 - google-cloud-storage: 2.8.0 - google-cloud-translate: 3.11.3 - google-colab: 1.0.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.6.0 - googleapis-common-protos: 1.60.0 - googledrivedownloader: 0.4 - graphviz: 0.20.1 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.57.0 - grpcio-status: 1.48.2 - gspread: 3.4.2 - gspread-dataframe: 3.3.1 - gym: 0.25.2 - gym-notices: 0.0.8 - h5netcdf: 1.2.0 - h5py: 3.9.0 - holidays: 0.32 - holoviews: 1.17.1 - html5lib: 1.1 - httpimport: 1.3.1 - httplib2: 0.22.0 - humanize: 4.7.0 - hyperopt: 0.2.7 - idna: 3.4 - imageio: 2.31.3 - imageio-ffmpeg: 0.4.8 - imagesize: 1.4.1 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - importlib-metadata: 
6.8.0 - importlib-resources: 6.0.1 - imutils: 0.5.4 - inflect: 7.0.0 - iniconfig: 2.0.0 - intel-openmp: 2023.2.0 - ipyevents: 2.0.2 - ipyfilechooser: 0.6.0 - ipykernel: 5.5.6 - ipyleaflet: 0.17.3 - ipython: 7.34.0 - ipython-genutils: 0.2.0 - ipython-sql: 0.5.0 - ipytree: 0.2.2 - ipywidgets: 7.7.1 - itsdangerous: 2.1.2 - jax: 0.4.14 - jaxlib: 0.4.14+cuda11.cudnn86 - jeepney: 0.7.1 - jieba: 0.42.1 - jinja2: 3.1.2 - joblib: 1.3.2 - jsonpickle: 3.0.2 - jsonschema: 4.19.0 - jsonschema-specifications: 2023.7.1 - jupyter-client: 6.1.12 - jupyter-console: 6.1.0 - jupyter-core: 5.3.1 - jupyter-server: 1.24.0 - jupyterlab-pygments: 0.2.2 - jupyterlab-widgets: 3.0.8 - kaggle: 1.5.16 - keras: 2.13.1 - keyring: 23.5.0 - kiwisolver: 1.4.5 - langcodes: 3.3.0 - launchpadlib: 1.10.16 - lazr.restfulclient: 0.14.4 - lazr.uri: 1.0.6 - lazy-loader: 0.3 - libclang: 16.0.6 - librosa: 0.10.1 - lightgbm: 4.0.0 - linkify-it-py: 2.0.2 - lit: 16.0.6 - llvmlite: 0.39.1 - locket: 1.0.0 - logical-unification: 0.4.6 - lunarcalendar: 0.0.9 - lxml: 4.9.3 - markdown: 3.4.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.9 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - minikanren: 1.0.3 - missingno: 0.5.2 - mistune: 0.8.4 - mizani: 0.9.3 - mkl: 2023.2.0 - ml-dtypes: 0.2.0 - mlxtend: 0.22.0 - more-itertools: 10.1.0 - moviepy: 1.0.3 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - multipledispatch: 1.0.0 - multitasking: 0.0.11 - murmurhash: 1.0.9 - music21: 9.1.0 - natsort: 8.4.0 - nbclassic: 1.0.0 - nbclient: 0.8.0 - nbconvert: 6.5.4 - nbformat: 5.9.2 - nest-asyncio: 1.5.7 - networkx: 3.1 - nibabel: 4.0.2 - nltk: 3.8.1 - notebook: 6.5.5 - notebook-shim: 0.2.3 - numba: 0.56.4 - numexpr: 2.8.5 - numpy: 1.23.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - opencv-contrib-python: 4.8.0.76 - opencv-python: 4.8.0.76 - opencv-python-headless: 4.8.0.76 - openpyxl: 3.1.2 - opt-einsum: 3.3.0 - optax: 0.1.7 - orbax-checkpoint: 0.3.5 - osqp: 0.6.2.post8 - packaging: 23.1 - pandas: 1.5.3 - pandas-datareader: 0.10.0 - pandas-gbq: 0.17.9 - pandocfilters: 1.5.0 - panel: 1.2.2 - param: 1.13.0 - parso: 0.8.3 - partd: 1.4.0 - pathlib: 1.0.1 - pathy: 0.10.2 - patsy: 0.5.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.1.2 - pip-tools: 6.13.0 - platformdirs: 3.10.0 - plotly: 5.15.0 - plotnine: 0.12.3 - pluggy: 1.3.0 - polars: 0.17.3 - pooch: 1.7.0 - portpicker: 1.5.2 - prefetch-generator: 1.0.3 - preshed: 3.0.8 - prettytable: 3.8.0 - proglog: 0.1.10 - progressbar2: 4.2.0 - prometheus-client: 0.17.1 - promise: 2.3 - prompt-toolkit: 3.0.39 - prophet: 1.1.4 - proto-plus: 1.22.3 - protobuf: 3.20.3 - psutil: 5.9.5 - psycopg2: 2.9.7 - ptyprocess: 0.7.0 - py-cpuinfo: 9.0.0 - py4j: 0.10.9.7 - pyarrow: 9.0.0 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycocotools: 2.0.7 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 1.10.12 - pydata-google-auth: 1.8.2 - pydot: 1.4.2 - pydot-ng: 2.0.0 - pydotplus: 2.0.2 - pydrive: 1.3.1 - pydrive2: 1.6.3 - pyerfa: 2.0.0.3 - pygame: 2.5.1 - pygments: 2.16.1 - pygobject: 3.42.1 - pyjwt: 2.3.0 - pymc: 5.7.2 - pymeeus: 0.5.12 - pymystem3: 0.2.0 - pyopengl: 3.1.7 - pyopenssl: 23.2.0 - pyparsing: 3.1.1 - pyperclip: 1.8.2 - pyproj: 3.6.0 - pyproject-hooks: 1.0.0 - pyshp: 2.3.1 - pysocks: 1.7.1 - pytensor: 2.14.2 - pytest: 7.4.1 - python-apt: 0.0.0 - python-box: 7.1.1 - python-dateutil: 2.8.2 - python-louvain: 0.16 - python-slugify: 8.0.1 - python-utils: 3.7.0 - pytz: 2023.3.post1 - pyviz-comms: 3.0.0 - pywavelets: 1.4.1 - pyyaml: 6.0.1 - pyzmq: 23.2.1 - 
qdldl: 0.1.7.post0 - qudida: 0.0.4 - ratelim: 0.1.6 - referencing: 0.30.2 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requirements-parser: 0.5.0 - rich: 13.5.2 - rpds-py: 0.10.2 - rpy2: 3.4.2 - rsa: 4.9 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.11.2 - scooby: 0.7.2 - scs: 3.2.3 - seaborn: 0.12.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - setuptools: 67.7.2 - shapely: 2.0.1 - six: 1.16.0 - sklearn-pandas: 2.2.0 - smart-open: 6.4.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soundfile: 0.12.1 - soupsieve: 2.5 - soxr: 0.3.6 - spacy: 3.6.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - sphinx: 5.0.2 - sphinxcontrib-applehelp: 1.0.7 - sphinxcontrib-devhelp: 1.0.5 - sphinxcontrib-htmlhelp: 2.0.4 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.6 - sphinxcontrib-serializinghtml: 1.1.9 - sqlalchemy: 2.0.20 - sqlparse: 0.4.4 - srsly: 2.4.7 - statsmodels: 0.14.0 - sympy: 1.12 - tables: 3.8.0 - tabulate: 0.9.0 - tbb: 2021.10.0 - tblib: 2.0.0 - tenacity: 8.2.3 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.1 - tensorflow: 2.13.0 - tensorflow-datasets: 4.9.2 - tensorflow-estimator: 2.13.0 - tensorflow-gcs-config: 2.13.0 - tensorflow-hub: 0.14.0 - tensorflow-io-gcs-filesystem: 0.33.0 - tensorflow-metadata: 1.14.0 - tensorflow-probability: 0.20.1 - tensorstore: 0.1.41 - termcolor: 2.3.0 - terminado: 0.17.1 - text-unidecode: 1.3 - textblob: 0.17.1 - tf-slim: 1.1.0 - thinc: 8.1.12 - threadpoolctl: 3.2.0 - tifffile: 2023.8.30 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.0.1+cu118 - torchaudio: 2.0.2+cu118 - torchdata: 0.6.1 - torchsummary: 1.5.1 - torchtext: 0.15.2 - torchvision: 0.15.2+cu118 - tornado: 6.3.2 - tqdm: 4.66.1 - traitlets: 5.7.1 - traittypes: 0.2.1 - triton: 2.0.0 - tweepy: 4.13.0 - typer: 0.9.0 - types-setuptools: 68.2.0.0 - typing-extensions: 4.5.0 - tzlocal: 5.0.1 - uc-micro-py: 1.0.2 - uritemplate: 4.1.1 - urllib3: 2.0.4 - vega-datasets: 0.9.0 - wadllib: 1.3.6 - wasabi: 1.1.2 - wcwidth: 0.2.6 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.2 - werkzeug: 2.3.7 - wheel: 0.41.2 - widgetsnbextension: 3.6.5 - wordcloud: 1.9.2 - wrapt: 1.15.0 - xarray: 2023.7.0 - xarray-einstats: 0.6.0 - xgboost: 1.7.6 - xlrd: 2.0.1 - xyzservices: 2023.7.0 - yarl: 1.9.2 - yellowbrick: 1.5 - yfinance: 0.2.28 - zict: 3.0.0 - zipp: 3.16.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.15.120+ - version: #1 SMP Wed Aug 30 11:19:59 UTC 2023

More info

No response

cc @carmocca @Blaizzy

Bhavay-2001 commented 1 year ago

Hello @awaelchli, can I work on this issue? I am a beginner in open source but would be really happy to help here. Thanks

MF-FOOM commented 1 year ago

@awaelchli @Bhavay-2001 Well, which solution do you think is best? No single one is immediately obvious to me as the best to implement.

I can see a couple possible solutions:

  • Simply add a warning about logging low precision types when on_epoch=True is enabled
  • Use torch reduction operations to mitigate associativity issues
  • Add a warning & use the torch reduction operations
  • Auto-cast logged values to float32 (or 64?) under the hood
Bhavay-2001 commented 1 year ago

Hi @MF-FOOM, I am also a beginner in open source. I am not sure about this and would like to discuss it with @awaelchli.

awaelchli commented 1 year ago

Here is the relevant code where the accumulation happens: https://github.com/Lightning-AI/lightning/blob/80f131c668f83d21779b0120957466a23b24a5af/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L205

I vote for converting floating-point scalars to full precision before accumulation, and against storing all values, because the user doesn't care about the internal representation but rather just the final reduced value. I also vote against warnings, because users generally ignore them.

So my suggestion is to call .float() before the summation.
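For illustration, a minimal sketch of that idea (the accumulate helper is hypothetical, not the actual code in result.py):

```python
import torch

# Hypothetical helper sketching the proposed fix: upcast low-precision
# floating-point values to float32 before adding them to the running
# accumulator.
def accumulate(running_sum: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    if value.is_floating_point():
        value = value.float()
    return running_sum + value

acc = torch.tensor(0.0)  # accumulator kept in float32
for _ in range(500):
    acc = accumulate(acc, torch.tensor(1.0, dtype=torch.bfloat16))
print(acc / 500)  # tensor(1.) -- the mean is now exact
```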

As a workaround in the current version, the user can do the same before passing the value to the log method.
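For example (the class, metric name, and loss computation here are illustrative):

```python
import torch
import lightning.pytorch as pl

class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss computation
        # Upcast to float32 before logging so the epoch-level reduction
        # runs in full precision even when training in bf16/fp16.
        self.log("train_loss", loss.float(), on_epoch=True)
        return loss
```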

MF-FOOM commented 1 year ago

sgtm! @Bhavay-2001 do you still want to give it a try?

Bhavay-2001 commented 1 year ago

Hi @MF-FOOM, maybe you can go ahead with this. I'm having a little trouble understanding it.

@awaelchli can you suggest some open source contributions for beginners?

awaelchli commented 1 year ago

@Bhavay-2001 Thanks for your interest. I suggest that you join the "want-to-contribute" Discord channel and we can find something that fits you.