Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

registered buffers' dtype is overridden after __init__ #18982

Open MF-FOOM opened 10 months ago

MF-FOOM commented 10 months ago

Bug description

If I register a float64 tensor as a buffer in the __init__ method of a LightningModule, like so:

self.register_buffer("testing_variable", torch.tensor([1,2,3], dtype=torch.float64))

It is cast to the Trainer's precision dtype during setup (bfloat16 with precision="bf16-true", as in the linked notebook), even though float64 was requested explicitly.
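The cast can be reproduced without a Trainer. This sketch assumes that precision="bf16-true" ends up converting the model with `nn.Module.to(dtype=torch.bfloat16)`; `BufferModule` and its `nn.Linear` layer are illustrative stand-ins for the LightningModule in the notebook, not code from the report.

```python
import torch
from torch import nn


class BufferModule(nn.Module):
    """Hypothetical stand-in for the LightningModule from the report."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(3, 1)
        # Explicitly request float64 for the buffer, as in the report.
        self.register_buffer(
            "testing_variable", torch.tensor([1, 2, 3], dtype=torch.float64)
        )


module = BufferModule()
print(module.testing_variable.dtype)  # torch.float64

# Assumed equivalent of the bf16-true conversion: cast the whole module.
module = module.to(dtype=torch.bfloat16)
print(module.testing_variable.dtype)  # torch.bfloat16
print(module.layer.weight.dtype)      # torch.bfloat16
```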

What version are you seeing the problem on?

v2.1

How to reproduce the bug

https://colab.research.google.com/drive/1UN1JeuiV3LM-LRrATigboJeDH9ehJifY

Error messages and logs

N/A

Environment

Current environment
* CUDA:
  - GPU: Tesla T4
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.1.1
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.1.1
  - torch: 2.1.0+cu118
  - torchaudio: 2.1.0+cu118
  - torchdata: 0.7.0
  - torchmetrics: 1.2.0
  - torchsummary: 1.5.1
  - torchtext: 0.16.0
  - torchvision: 0.16.0+cu118
* Packages: absl-py: 1.4.0, aiohttp: 3.8.6, aiosignal: 1.3.1, alabaster: 0.7.13, albumentations: 1.3.1, altair: 4.2.2, anyio: 3.7.1, appdirs: 1.4.4, argon2-cffi: 23.1.0, argon2-cffi-bindings: 21.2.0, array-record: 0.5.0, arviz: 0.15.1, astropy: 5.3.4, astunparse: 1.6.3, async-timeout: 4.0.3, atpublic: 4.0, attrs: 23.1.0, audioread: 3.0.1, autograd: 1.6.2, babel: 2.13.1, backcall: 0.2.0, beautifulsoup4: 4.11.2, bidict: 0.22.1, bigframes: 0.12.0, bleach: 6.1.0, blinker: 1.4, blis: 0.7.11, blosc2: 2.0.0, bokeh: 3.3.0, bqplot: 0.12.42, branca: 0.7.0, build: 1.0.3, cachecontrol: 0.13.1, cachetools: 5.3.2, catalogue: 2.0.10, certifi: 2023.7.22, cffi: 1.16.0, chardet: 5.2.0, charset-normalizer: 3.3.2, chex: 0.1.7, click: 8.1.7, click-plugins: 1.1.1, cligj: 0.7.2, cloudpickle: 2.2.1, cmake: 3.27.7, cmdstanpy: 1.2.0, colorcet: 3.0.1, colorlover: 0.3.0, colour: 0.1.5, community: 1.0.0b1, confection: 0.1.3, cons: 0.4.6, contextlib2: 21.6.0, contourpy: 1.2.0, cryptography: 41.0.5, cufflinks: 0.17.3, cupy-cuda11x: 11.0.0, cvxopt: 1.3.2, cvxpy: 1.3.2, cycler: 0.12.1, cymem: 2.0.8, cython: 3.0.5, dask: 2023.8.1, datascience: 0.17.6, db-dtypes: 1.1.1, dbus-python: 1.2.18, debugpy: 1.6.6, decorator: 4.4.2, defusedxml: 0.7.1, diskcache: 5.6.3, distributed: 2023.8.1, distro: 1.7.0, dlib: 19.24.2, dm-tree: 0.1.8, docutils: 0.18.1, dopamine-rl: 4.0.6, duckdb: 0.9.1, earthengine-api: 0.1.377, easydict: 1.11, ecos: 2.0.12, editdistance: 0.6.2, eerepr: 0.0.4, en-core-web-sm: 3.6.0, entrypoints: 0.4, et-xmlfile: 1.1.0, etils: 1.5.2, etuples: 0.3.9, exceptiongroup: 1.1.3, fastai: 2.7.13, fastcore: 1.5.29, fastdownload: 0.0.7, fastjsonschema: 2.18.1, fastprogress: 1.0.3, fastrlock: 0.8.2, filelock: 3.13.1, fiona: 1.9.5, firebase-admin: 5.3.0, flask: 2.2.5, flatbuffers: 23.5.26, flax: 0.7.5, folium: 0.14.0, fonttools: 4.44.0, frozendict: 2.3.8, frozenlist: 1.4.0, fsspec: 2023.6.0, future: 0.18.3, gast: 0.5.4, gcsfs: 2023.6.0, gdal: 3.4.3, gdown: 4.6.6, geemap: 0.28.2, gensim: 4.3.2, geocoder: 1.38.1, geographiclib: 2.0, geopandas: 0.13.2, geopy: 2.3.0, gin-config: 0.5.0, glob2: 0.7, google: 2.0.3, google-api-core: 2.11.1, google-api-python-client: 2.84.0, google-auth: 2.17.3, google-auth-httplib2: 0.1.1, google-auth-oauthlib: 1.0.0, google-cloud-bigquery: 3.12.0, google-cloud-bigquery-connection: 1.12.1, google-cloud-bigquery-storage: 2.22.0, google-cloud-core: 2.3.3, google-cloud-datastore: 2.15.2, google-cloud-firestore: 2.11.1, google-cloud-functions: 1.13.3, google-cloud-iam: 2.12.2, google-cloud-language: 2.9.1, google-cloud-resource-manager: 1.10.4, google-cloud-storage: 2.8.0, google-cloud-translate: 3.11.3, google-colab: 1.0.0, google-crc32c: 1.5.0, google-pasta: 0.2.0, google-resumable-media: 2.6.0, googleapis-common-protos: 1.61.0, googledrivedownloader: 0.4, graphviz: 0.20.1, greenlet: 3.0.1, grpc-google-iam-v1: 0.12.6, grpcio: 1.59.2, grpcio-status: 1.48.2, gspread: 3.4.2, gspread-dataframe: 3.3.1, gym: 0.25.2, gym-notices: 0.0.8, h5netcdf: 1.3.0, h5py: 3.9.0, holidays: 0.36, holoviews: 1.17.1, html5lib: 1.1, httpimport: 1.3.1, httplib2: 0.22.0, humanize: 4.7.0, hyperopt: 0.2.7, ibis-framework: 6.2.0, idna: 3.4, imageio: 2.31.6, imageio-ffmpeg: 0.4.9, imagesize: 1.4.1, imbalanced-learn: 0.10.1, imgaug: 0.4.0, importlib-metadata: 6.8.0, importlib-resources: 6.1.1, imutils: 0.5.4, inflect: 7.0.0, iniconfig: 2.0.0, install: 1.3.5, intel-openmp: 2023.2.0, ipyevents: 2.0.2, ipyfilechooser: 0.6.0, ipykernel: 5.5.6, ipyleaflet: 0.17.4, ipython: 7.34.0, ipython-genutils: 0.2.0, ipython-sql: 0.5.0, ipytree: 0.2.2, ipywidgets: 7.7.1, itsdangerous: 2.1.2, jax: 0.4.20, jaxlib: 0.4.20+cuda11.cudnn86, jeepney: 0.7.1, jieba: 0.42.1, jinja2: 3.1.2, joblib: 1.3.2, jsonpickle: 3.0.2, jsonschema: 4.19.2, jsonschema-specifications: 2023.7.1, jupyter-client: 6.1.12, jupyter-console: 6.1.0, jupyter-core: 5.5.0, jupyter-server: 1.24.0, jupyterlab-pygments: 0.2.2, jupyterlab-widgets: 3.0.9, kaggle: 1.5.16, keras: 2.14.0, keyring: 23.5.0, kiwisolver: 1.4.5, langcodes: 3.3.0, launchpadlib: 1.10.16, lazr.restfulclient: 0.14.4, lazr.uri: 1.0.6, lazy-loader: 0.3, libclang: 16.0.6, librosa: 0.10.1, lida: 0.0.10, lightgbm: 4.1.0, lightning: 2.1.1, lightning-utilities: 0.9.0, linkify-it-py: 2.0.2, llmx: 0.0.15a0, llvmlite: 0.41.1, locket: 1.0.0, logical-unification: 0.4.6, lxml: 4.9.3, malloy: 2023.1064, markdown: 3.5.1, markdown-it-py: 3.0.0, markupsafe: 2.1.3, matplotlib: 3.7.1, matplotlib-inline: 0.1.6, matplotlib-venn: 0.11.9, mdit-py-plugins: 0.4.0, mdurl: 0.1.2, minikanren: 1.0.3, missingno: 0.5.2, mistune: 0.8.4, mizani: 0.9.3, mkl: 2023.2.0, ml-dtypes: 0.2.0, mlxtend: 0.22.0, more-itertools: 10.1.0, moviepy: 1.0.3, mpmath: 1.3.0, msgpack: 1.0.7, multidict: 6.0.4, multipledispatch: 1.0.0, multitasking: 0.0.11, murmurhash: 1.0.10, music21: 9.1.0, natsort: 8.4.0, nbclassic: 1.0.0, nbclient: 0.9.0, nbconvert: 6.5.4, nbformat: 5.9.2, nest-asyncio: 1.5.8, networkx: 3.2.1, nibabel: 4.0.2, nltk: 3.8.1, notebook: 6.5.5, notebook-shim: 0.2.3, numba: 0.58.1, numexpr: 2.8.7, numpy: 1.23.5, oauth2client: 4.1.3, oauthlib: 3.2.2, opencv-contrib-python: 4.8.0.76, opencv-python: 4.8.0.76, opencv-python-headless: 4.8.1.78, openpyxl: 3.1.2, opt-einsum: 3.3.0, optax: 0.1.7, orbax-checkpoint: 0.4.2, osqp: 0.6.2.post8, packaging: 23.2, pandas: 1.5.3, pandas-datareader: 0.10.0, pandas-gbq: 0.17.9, pandas-stubs: 1.5.3.230304, pandocfilters: 1.5.0, panel: 1.3.1, param: 2.0.0, parso: 0.8.3, parsy: 2.1, partd: 1.4.1, pathlib: 1.0.1, pathy: 0.10.3, patsy: 0.5.3, peewee: 3.17.0, pexpect: 4.8.0, pickleshare: 0.7.5, pillow: 9.4.0, pip: 23.1.2, pip-tools: 6.13.0, platformdirs: 3.11.0, plotly: 5.15.0, plotnine: 0.12.4, pluggy: 1.3.0, polars: 0.17.3, pooch: 1.8.0, portpicker: 1.5.2, prefetch-generator: 1.0.3, preshed: 3.0.9, prettytable: 3.9.0, proglog: 0.1.10, progressbar2: 4.2.0, prometheus-client: 0.18.0, promise: 2.3, prompt-toolkit: 3.0.39, prophet: 1.1.5, proto-plus: 1.22.3, protobuf: 3.20.3, psutil: 5.9.5, psycopg2: 2.9.9, ptyprocess: 0.7.0, py-cpuinfo: 9.0.0, py4j: 0.10.9.7, pyarrow: 9.0.0, pyasn1: 0.5.0, pyasn1-modules: 0.3.0, pycocotools: 2.0.7, pycparser: 2.21, pyct: 0.5.0, pydantic: 1.10.13, pydata-google-auth: 1.8.2, pydot: 1.4.2, pydot-ng: 2.0.0, pydotplus: 2.0.2, pydrive: 1.3.1, pydrive2: 1.6.3, pyerfa: 2.0.1.1, pygame: 2.5.2, pygments: 2.16.1, pygobject: 3.42.1, pyjwt: 2.3.0, pymc: 5.7.2, pymystem3: 0.2.0, pyopengl: 3.1.7, pyopenssl: 23.3.0, pyparsing: 3.1.1, pyperclip: 1.8.2, pyproj: 3.6.1, pyproject-hooks: 1.0.0, pyshp: 2.3.1, pysocks: 1.7.1, pytensor: 2.14.2, pytest: 7.4.3, python-apt: 0.0.0, python-box: 7.1.1, python-dateutil: 2.8.2, python-louvain: 0.16, python-slugify: 8.0.1, python-utils: 3.8.1, pytorch-lightning: 2.1.1, pytz: 2023.3.post1, pyviz-comms: 3.0.0, pywavelets: 1.4.1, pyyaml: 6.0.1, pyzmq: 23.2.1, qdldl: 0.1.7.post0, qudida: 0.0.4, ratelim: 0.1.6, referencing: 0.30.2, regex: 2023.6.3, requests: 2.31.0, requests-oauthlib: 1.3.1, requirements-parser: 0.5.0, rich: 13.6.0, rpds-py: 0.12.0, rpy2: 3.4.2, rsa: 4.9, scikit-image: 0.19.3, scikit-learn: 1.2.2, scipy: 1.11.3, scooby: 0.9.2, scs: 3.2.3, seaborn: 0.12.2, secretstorage: 3.3.1, send2trash: 1.8.2, setuptools: 67.7.2, shapely: 2.0.2, six: 1.16.0, sklearn-pandas: 2.2.0, smart-open: 6.4.0, sniffio: 1.3.0, snowballstemmer: 2.2.0, sortedcontainers: 2.4.0, soundfile: 0.12.1, soupsieve: 2.5, soxr: 0.3.7, spacy: 3.6.1, spacy-legacy: 3.0.12, spacy-loggers: 1.0.5, sphinx: 5.0.2, sphinxcontrib-applehelp: 1.0.7, sphinxcontrib-devhelp: 1.0.5, sphinxcontrib-htmlhelp: 2.0.4, sphinxcontrib-jsmath: 1.0.1, sphinxcontrib-qthelp: 1.0.6, sphinxcontrib-serializinghtml: 1.1.9, sqlalchemy: 2.0.23, sqlglot: 17.16.2, sqlparse: 0.4.4, srsly: 2.4.8, stanio: 0.3.0, statsmodels: 0.14.0, sympy: 1.12, tables: 3.8.0, tabulate: 0.9.0, tbb: 2021.10.0, tblib: 3.0.0, tenacity: 8.2.3, tensorboard: 2.14.1, tensorboard-data-server: 0.7.2, tensorflow: 2.14.0, tensorflow-datasets: 4.9.3, tensorflow-estimator: 2.14.0, tensorflow-gcs-config: 2.14.0, tensorflow-hub: 0.15.0, tensorflow-io-gcs-filesystem: 0.34.0, tensorflow-metadata: 1.14.0, tensorflow-probability: 0.22.0, tensorstore: 0.1.45, termcolor: 2.3.0, terminado: 0.17.1, text-unidecode: 1.3, textblob: 0.17.1, tf-slim: 1.1.0, thinc: 8.1.12, threadpoolctl: 3.2.0, tifffile: 2023.9.26, tinycss2: 1.2.1, toml: 0.10.2, tomli: 2.0.1, toolz: 0.12.0, torch: 2.1.0+cu118, torchaudio: 2.1.0+cu118, torchdata: 0.7.0, torchmetrics: 1.2.0, torchsummary: 1.5.1, torchtext: 0.16.0, torchvision: 0.16.0+cu118, tornado: 6.3.2, tqdm: 4.66.1, traitlets: 5.7.1, traittypes: 0.2.1, triton: 2.1.0, tweepy: 4.14.0, typer: 0.9.0, types-pytz: 2023.3.1.1, types-setuptools: 68.2.0.0, typing-extensions: 4.5.0, tzlocal: 5.2, uc-micro-py: 1.0.2, uritemplate: 4.1.1, urllib3: 2.0.7, vega-datasets: 0.9.0, wadllib: 1.3.6, wasabi: 1.1.2, wcwidth: 0.2.9, webcolors: 1.13, webencodings: 0.5.1, websocket-client: 1.6.4, werkzeug: 3.0.1, wheel: 0.41.3, widgetsnbextension: 3.6.6, wordcloud: 1.9.2, wrapt: 1.14.1, xarray: 2023.7.0, xarray-einstats: 0.6.0, xgboost: 2.0.1, xlrd: 2.0.1, xxhash: 3.4.1, xyzservices: 2023.10.1, yarl: 1.9.2, yellowbrick: 1.5, yfinance: 0.2.31, zict: 3.0.0, zipp: 3.17.0
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 5.15.120+
  - version: #1 SMP Wed Aug 30 11:19:59 UTC 2023

More info

No response

cc @borda @carmocca @justusschock @awaelchli

carmocca commented 10 months ago

This is caused by this line: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/half.py#L42
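Assuming the linked line boils down to `Module.to(dtype=...)`, the cast is not specific to user-registered buffers: PyTorch casts every floating-point parameter and buffer to the target dtype and leaves integer buffers alone. The sketch below uses `nn.BatchNorm1d` only because it happens to own both kinds of buffer.

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
# BatchNorm keeps float running statistics and an integer step counter.
print(bn.running_mean.dtype)         # torch.float32
print(bn.num_batches_tracked.dtype)  # torch.int64

bn = bn.to(dtype=torch.bfloat16)
print(bn.weight.dtype)               # torch.bfloat16 (parameter)
print(bn.running_mean.dtype)         # torch.bfloat16 (floating-point buffer)
print(bn.num_batches_tracked.dtype)  # torch.int64 (integer buffer, untouched)
```

So a float64 buffer is treated exactly like a float32 one: both count as floating point and both follow the requested dtype.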

We could:

I don't see a perfect solution here, there will always be edge cases. What's your opinion @awaelchli?

awaelchli commented 8 months ago

In the provided code, the user chose precision="bf16-true", which is the explicit way of saying "I want everything in bfloat16", and that is exactly what Lightning does. Excluding buffers by default would be an arbitrary choice for the framework to make. Besides, if buffers were excluded, the user would likely have to change the code in forward where the buffers are used, since they would no longer match the dtype of the parameters and activations.
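To illustrate that last point: if a buffer stayed in float64 while parameters and inputs are bfloat16, mixing the two would promote the activations to float64, and the next bf16 layer would reject them. This is a plain PyTorch sketch assuming a build with CPU bfloat16 matmul support; `scale` stands in for a hypothetically excluded buffer.

```python
import torch
from torch import nn

torch.manual_seed(0)

layer = nn.Linear(3, 1).to(dtype=torch.bfloat16)            # parameters cast, as with bf16-true
scale = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)  # buffer hypothetically left in float64
x = torch.randn(2, 3, dtype=torch.bfloat16)                 # bf16 activations

scaled = x * scale  # type promotion: bfloat16 * float64 -> float64
print(scaled.dtype)  # torch.float64

mismatch_raised = False
try:
    layer(scaled)  # float64 input vs. bfloat16 weight
except RuntimeError as err:
    mismatch_raised = True
    print("dtype mismatch:", err)

# The forward pass would need an explicit cast back to the activation dtype.
out = layer(x * scale.to(x.dtype))
print(out.dtype)  # torch.bfloat16
```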

In our documentation, we could make this clear by mentioning the buffers in this sentence: https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#true-half-precision

If there is a strong desire to exclude buffers, one could add a flag to the HalfPrecision plugin:

trainer = Trainer(precision=HalfPrecision(buffers=False))
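A minimal sketch of what such a `buffers=False` option could do, written as a standalone PyTorch helper. `convert_module` and its `buffers` argument are hypothetical and not part of Lightning's API; the sketch always casts floating-point parameters and only touches buffers when asked to.

```python
import torch
from torch import nn


def convert_module(module: nn.Module, dtype: torch.dtype, buffers: bool = True) -> nn.Module:
    """Hypothetical conversion with an opt-out for buffers (not a Lightning API)."""
    if buffers:
        # Current behavior: floating-point parameters and buffers are cast.
        return module.to(dtype=dtype)
    # Opt-out: cast floating-point parameters in place, leave buffers as registered.
    for param in module.parameters():
        if param.is_floating_point():
            param.data = param.data.to(dtype)
    return module


class BufferModule(nn.Module):
    """Hypothetical module with one layer and one float64 buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(3, 1)
        self.register_buffer("testing_variable", torch.tensor([1, 2, 3], dtype=torch.float64))


module = convert_module(BufferModule(), torch.bfloat16, buffers=False)
print(module.layer.weight.dtype)      # torch.bfloat16
print(module.testing_variable.dtype)  # torch.float64
```

Note that with such an opt-out, code in forward that combines the buffer with bf16 activations would still need an explicit cast.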

carmocca commented 8 months ago

Adding a flag makes sense to me, similar to how one might want to control the output dtype, which would also be a flag.
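Outside Lightning, the output-dtype idea can be approximated with a forward hook that casts the module's result. `cast_outputs` is a hypothetical helper for illustration only; it relies on PyTorch's documented behavior that a forward hook may return a value that replaces the original output.

```python
import torch
from torch import nn


def cast_outputs(module: nn.Module, dtype: torch.dtype):
    """Hypothetical helper: cast floating-point tensor outputs of `module` to `dtype`."""

    def hook(mod, args, output):
        if isinstance(output, torch.Tensor) and output.is_floating_point():
            return output.to(dtype)  # returned value replaces the original output
        return output  # leave non-tensor or integer outputs unchanged

    return module.register_forward_hook(hook)


model = nn.Linear(3, 1).to(dtype=torch.bfloat16)
handle = cast_outputs(model, torch.float32)
out = model(torch.randn(2, 3, dtype=torch.bfloat16))
print(out.dtype)  # torch.float32
handle.remove()  # detach the hook when it is no longer needed
```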