Open MF-FOOM opened 10 months ago
This is caused by this line: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/half.py#L42
We could:
torch.set_default_dtype?
I don't see a perfect solution here; there will always be edge cases. What's your opinion, @awaelchli?
In the provided code, the user chose precision="bf16-true", which is the explicit way of saying "I want everything in bfloat16", and this is what Lightning does. Excluding the buffers by default would be a very arbitrary choice for the framework to make. Besides, if that were the case, the user would likely have to change their code in forward wherever they use the buffers.
In our documentation, we could make this clear by mentioning the buffers in this sentence: https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#true-half-precision
If there is a strong desire to exclude buffers, one could add a flag to the HalfPrecision plugin:
trainer = Trainer(precision=HalfPrecision(buffers=False))
Adding a flag makes sense to me. Similarly, one might want to control the output dtype, which would also be a flag.
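To make the trade-off concrete, here is a rough sketch in plain PyTorch of what "exclude buffers" could mean. The helper `cast_parameters_only` and the `Scaled` module are hypothetical names made up for illustration, and the `buffers=False` flag above is only a proposal, not an existing Lightning option. The sketch also shows the point made earlier: once the buffer keeps its own dtype, forward has to cast it explicitly.

```python
import torch
from torch import nn


def cast_parameters_only(module: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Hypothetical helper: cast floating-point parameters, leave buffers alone."""
    for param in module.parameters():
        if param.is_floating_point():
            param.data = param.data.to(dtype)
    return module


class Scaled(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        # Buffer intentionally kept in float64.
        self.register_buffer("scale", torch.ones(4, dtype=torch.float64))

    def forward(self, x):
        # The buffer is still float64, so it is cast to the activation
        # dtype here to avoid an implicit promotion of the output.
        return self.layer(x) * self.scale.to(x.dtype)


model = cast_parameters_only(Scaled(), torch.bfloat16)
out = model(torch.ones(2, 4, dtype=torch.bfloat16))
print(model.layer.weight.dtype, model.scale.dtype, out.dtype)
```

Whether the explicit cast in forward is acceptable is exactly the design question under discussion; without it, multiplying a bfloat16 activation by a float64 buffer would promote the result.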
Bug description
If I register a float64 tensor to a buffer in the __init__ function of a LightningModule like so:
It will get cast into the Trainer's precision type after setup, regardless of whether a different dtype (e.g. float64) was intended.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
https://colab.research.google.com/drive/1UN1JeuiV3LM-LRrATigboJeDH9ehJifY
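The notebook itself is not reproduced here. As a minimal sketch of the underlying behaviour, the snippet below uses plain PyTorch instead of Lightning and assumes that the conversion discussed in this thread boils down to a `Module.to(dtype=...)` call. `ModelWithBuffer` and its buffer names are hypothetical.

```python
import torch
from torch import nn


class ModelWithBuffer(nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        # Floating-point buffer deliberately created in float64.
        self.register_buffer("scale", torch.ones(4, dtype=torch.float64))
        # Integer buffer, included for comparison.
        self.register_buffer("step", torch.zeros(1, dtype=torch.long))


model = ModelWithBuffer()
dtype_before = model.scale.dtype

# Module.to(dtype=...) casts all floating-point parameters AND buffers.
model = model.to(dtype=torch.bfloat16)

dtype_after = model.scale.dtype
step_dtype_after = model.step.dtype
print(dtype_before, dtype_after, step_dtype_after)
```

The float64 buffer ends up in bfloat16 along with the parameters, while the integer buffer keeps its dtype, since `Module.to` only casts floating-point (and complex) tensors.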
Error messages and logs
N/A
Environment
Current environment
* CUDA: - GPU: - Tesla T4 - available: True - version: 11.8 * Lightning: - lightning: 2.1.1 - lightning-utilities: 0.9.0 - pytorch-lightning: 2.1.1 - torch: 2.1.0+cu118 - torchaudio: 2.1.0+cu118 - torchdata: 0.7.0 - torchmetrics: 1.2.0 - torchsummary: 1.5.1 - torchtext: 0.16.0 - torchvision: 0.16.0+cu118 * Packages: - absl-py: 1.4.0 - aiohttp: 3.8.6 - aiosignal: 1.3.1 - alabaster: 0.7.13 - albumentations: 1.3.1 - altair: 4.2.2 - anyio: 3.7.1 - appdirs: 1.4.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - array-record: 0.5.0 - arviz: 0.15.1 - astropy: 5.3.4 - astunparse: 1.6.3 - async-timeout: 4.0.3 - atpublic: 4.0 - attrs: 23.1.0 - audioread: 3.0.1 - autograd: 1.6.2 - babel: 2.13.1 - backcall: 0.2.0 - beautifulsoup4: 4.11.2 - bidict: 0.22.1 - bigframes: 0.12.0 - bleach: 6.1.0 - blinker: 1.4 - blis: 0.7.11 - blosc2: 2.0.0 - bokeh: 3.3.0 - bqplot: 0.12.42 - branca: 0.7.0 - build: 1.0.3 - cachecontrol: 0.13.1 - cachetools: 5.3.2 - catalogue: 2.0.10 - certifi: 2023.7.22 - cffi: 1.16.0 - chardet: 5.2.0 - charset-normalizer: 3.3.2 - chex: 0.1.7 - click: 8.1.7 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloudpickle: 2.2.1 - cmake: 3.27.7 - cmdstanpy: 1.2.0 - colorcet: 3.0.1 - colorlover: 0.3.0 - colour: 0.1.5 - community: 1.0.0b1 - confection: 0.1.3 - cons: 0.4.6 - contextlib2: 21.6.0 - contourpy: 1.2.0 - cryptography: 41.0.5 - cufflinks: 0.17.3 - cupy-cuda11x: 11.0.0 - cvxopt: 1.3.2 - cvxpy: 1.3.2 - cycler: 0.12.1 - cymem: 2.0.8 - cython: 3.0.5 - dask: 2023.8.1 - datascience: 0.17.6 - db-dtypes: 1.1.1 - dbus-python: 1.2.18 - debugpy: 1.6.6 - decorator: 4.4.2 - defusedxml: 0.7.1 - diskcache: 5.6.3 - distributed: 2023.8.1 - distro: 1.7.0 - dlib: 19.24.2 - dm-tree: 0.1.8 - docutils: 0.18.1 - dopamine-rl: 4.0.6 - duckdb: 0.9.1 - earthengine-api: 0.1.377 - easydict: 1.11 - ecos: 2.0.12 - editdistance: 0.6.2 - eerepr: 0.0.4 - en-core-web-sm: 3.6.0 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - etils: 1.5.2 - etuples: 0.3.9 - exceptiongroup: 1.1.3 - fastai: 2.7.13 - fastcore: 
1.5.29 - fastdownload: 0.0.7 - fastjsonschema: 2.18.1 - fastprogress: 1.0.3 - fastrlock: 0.8.2 - filelock: 3.13.1 - fiona: 1.9.5 - firebase-admin: 5.3.0 - flask: 2.2.5 - flatbuffers: 23.5.26 - flax: 0.7.5 - folium: 0.14.0 - fonttools: 4.44.0 - frozendict: 2.3.8 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - future: 0.18.3 - gast: 0.5.4 - gcsfs: 2023.6.0 - gdal: 3.4.3 - gdown: 4.6.6 - geemap: 0.28.2 - gensim: 4.3.2 - geocoder: 1.38.1 - geographiclib: 2.0 - geopandas: 0.13.2 - geopy: 2.3.0 - gin-config: 0.5.0 - glob2: 0.7 - google: 2.0.3 - google-api-core: 2.11.1 - google-api-python-client: 2.84.0 - google-auth: 2.17.3 - google-auth-httplib2: 0.1.1 - google-auth-oauthlib: 1.0.0 - google-cloud-bigquery: 3.12.0 - google-cloud-bigquery-connection: 1.12.1 - google-cloud-bigquery-storage: 2.22.0 - google-cloud-core: 2.3.3 - google-cloud-datastore: 2.15.2 - google-cloud-firestore: 2.11.1 - google-cloud-functions: 1.13.3 - google-cloud-iam: 2.12.2 - google-cloud-language: 2.9.1 - google-cloud-resource-manager: 1.10.4 - google-cloud-storage: 2.8.0 - google-cloud-translate: 3.11.3 - google-colab: 1.0.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.6.0 - googleapis-common-protos: 1.61.0 - googledrivedownloader: 0.4 - graphviz: 0.20.1 - greenlet: 3.0.1 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.59.2 - grpcio-status: 1.48.2 - gspread: 3.4.2 - gspread-dataframe: 3.3.1 - gym: 0.25.2 - gym-notices: 0.0.8 - h5netcdf: 1.3.0 - h5py: 3.9.0 - holidays: 0.36 - holoviews: 1.17.1 - html5lib: 1.1 - httpimport: 1.3.1 - httplib2: 0.22.0 - humanize: 4.7.0 - hyperopt: 0.2.7 - ibis-framework: 6.2.0 - idna: 3.4 - imageio: 2.31.6 - imageio-ffmpeg: 0.4.9 - imagesize: 1.4.1 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - importlib-metadata: 6.8.0 - importlib-resources: 6.1.1 - imutils: 0.5.4 - inflect: 7.0.0 - iniconfig: 2.0.0 - install: 1.3.5 - intel-openmp: 2023.2.0 - ipyevents: 2.0.2 - ipyfilechooser: 0.6.0 - ipykernel: 5.5.6 - ipyleaflet: 0.17.4 - ipython: 7.34.0 - 
ipython-genutils: 0.2.0 - ipython-sql: 0.5.0 - ipytree: 0.2.2 - ipywidgets: 7.7.1 - itsdangerous: 2.1.2 - jax: 0.4.20 - jaxlib: 0.4.20+cuda11.cudnn86 - jeepney: 0.7.1 - jieba: 0.42.1 - jinja2: 3.1.2 - joblib: 1.3.2 - jsonpickle: 3.0.2 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.7.1 - jupyter-client: 6.1.12 - jupyter-console: 6.1.0 - jupyter-core: 5.5.0 - jupyter-server: 1.24.0 - jupyterlab-pygments: 0.2.2 - jupyterlab-widgets: 3.0.9 - kaggle: 1.5.16 - keras: 2.14.0 - keyring: 23.5.0 - kiwisolver: 1.4.5 - langcodes: 3.3.0 - launchpadlib: 1.10.16 - lazr.restfulclient: 0.14.4 - lazr.uri: 1.0.6 - lazy-loader: 0.3 - libclang: 16.0.6 - librosa: 0.10.1 - lida: 0.0.10 - lightgbm: 4.1.0 - lightning: 2.1.1 - lightning-utilities: 0.9.0 - linkify-it-py: 2.0.2 - llmx: 0.0.15a0 - llvmlite: 0.41.1 - locket: 1.0.0 - logical-unification: 0.4.6 - lxml: 4.9.3 - malloy: 2023.1064 - markdown: 3.5.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.9 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - minikanren: 1.0.3 - missingno: 0.5.2 - mistune: 0.8.4 - mizani: 0.9.3 - mkl: 2023.2.0 - ml-dtypes: 0.2.0 - mlxtend: 0.22.0 - more-itertools: 10.1.0 - moviepy: 1.0.3 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.4 - multipledispatch: 1.0.0 - multitasking: 0.0.11 - murmurhash: 1.0.10 - music21: 9.1.0 - natsort: 8.4.0 - nbclassic: 1.0.0 - nbclient: 0.9.0 - nbconvert: 6.5.4 - nbformat: 5.9.2 - nest-asyncio: 1.5.8 - networkx: 3.2.1 - nibabel: 4.0.2 - nltk: 3.8.1 - notebook: 6.5.5 - notebook-shim: 0.2.3 - numba: 0.58.1 - numexpr: 2.8.7 - numpy: 1.23.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - opencv-contrib-python: 4.8.0.76 - opencv-python: 4.8.0.76 - opencv-python-headless: 4.8.1.78 - openpyxl: 3.1.2 - opt-einsum: 3.3.0 - optax: 0.1.7 - orbax-checkpoint: 0.4.2 - osqp: 0.6.2.post8 - packaging: 23.2 - pandas: 1.5.3 - pandas-datareader: 0.10.0 - pandas-gbq: 0.17.9 - pandas-stubs: 1.5.3.230304 - pandocfilters: 1.5.0 - panel: 
1.3.1 - param: 2.0.0 - parso: 0.8.3 - parsy: 2.1 - partd: 1.4.1 - pathlib: 1.0.1 - pathy: 0.10.3 - patsy: 0.5.3 - peewee: 3.17.0 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.1.2 - pip-tools: 6.13.0 - platformdirs: 3.11.0 - plotly: 5.15.0 - plotnine: 0.12.4 - pluggy: 1.3.0 - polars: 0.17.3 - pooch: 1.8.0 - portpicker: 1.5.2 - prefetch-generator: 1.0.3 - preshed: 3.0.9 - prettytable: 3.9.0 - proglog: 0.1.10 - progressbar2: 4.2.0 - prometheus-client: 0.18.0 - promise: 2.3 - prompt-toolkit: 3.0.39 - prophet: 1.1.5 - proto-plus: 1.22.3 - protobuf: 3.20.3 - psutil: 5.9.5 - psycopg2: 2.9.9 - ptyprocess: 0.7.0 - py-cpuinfo: 9.0.0 - py4j: 0.10.9.7 - pyarrow: 9.0.0 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycocotools: 2.0.7 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 1.10.13 - pydata-google-auth: 1.8.2 - pydot: 1.4.2 - pydot-ng: 2.0.0 - pydotplus: 2.0.2 - pydrive: 1.3.1 - pydrive2: 1.6.3 - pyerfa: 2.0.1.1 - pygame: 2.5.2 - pygments: 2.16.1 - pygobject: 3.42.1 - pyjwt: 2.3.0 - pymc: 5.7.2 - pymystem3: 0.2.0 - pyopengl: 3.1.7 - pyopenssl: 23.3.0 - pyparsing: 3.1.1 - pyperclip: 1.8.2 - pyproj: 3.6.1 - pyproject-hooks: 1.0.0 - pyshp: 2.3.1 - pysocks: 1.7.1 - pytensor: 2.14.2 - pytest: 7.4.3 - python-apt: 0.0.0 - python-box: 7.1.1 - python-dateutil: 2.8.2 - python-louvain: 0.16 - python-slugify: 8.0.1 - python-utils: 3.8.1 - pytorch-lightning: 2.1.1 - pytz: 2023.3.post1 - pyviz-comms: 3.0.0 - pywavelets: 1.4.1 - pyyaml: 6.0.1 - pyzmq: 23.2.1 - qdldl: 0.1.7.post0 - qudida: 0.0.4 - ratelim: 0.1.6 - referencing: 0.30.2 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requirements-parser: 0.5.0 - rich: 13.6.0 - rpds-py: 0.12.0 - rpy2: 3.4.2 - rsa: 4.9 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.11.3 - scooby: 0.9.2 - scs: 3.2.3 - seaborn: 0.12.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - setuptools: 67.7.2 - shapely: 2.0.2 - six: 1.16.0 - sklearn-pandas: 2.2.0 - smart-open: 6.4.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - 
sortedcontainers: 2.4.0 - soundfile: 0.12.1 - soupsieve: 2.5 - soxr: 0.3.7 - spacy: 3.6.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.5 - sphinx: 5.0.2 - sphinxcontrib-applehelp: 1.0.7 - sphinxcontrib-devhelp: 1.0.5 - sphinxcontrib-htmlhelp: 2.0.4 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.6 - sphinxcontrib-serializinghtml: 1.1.9 - sqlalchemy: 2.0.23 - sqlglot: 17.16.2 - sqlparse: 0.4.4 - srsly: 2.4.8 - stanio: 0.3.0 - statsmodels: 0.14.0 - sympy: 1.12 - tables: 3.8.0 - tabulate: 0.9.0 - tbb: 2021.10.0 - tblib: 3.0.0 - tenacity: 8.2.3 - tensorboard: 2.14.1 - tensorboard-data-server: 0.7.2 - tensorflow: 2.14.0 - tensorflow-datasets: 4.9.3 - tensorflow-estimator: 2.14.0 - tensorflow-gcs-config: 2.14.0 - tensorflow-hub: 0.15.0 - tensorflow-io-gcs-filesystem: 0.34.0 - tensorflow-metadata: 1.14.0 - tensorflow-probability: 0.22.0 - tensorstore: 0.1.45 - termcolor: 2.3.0 - terminado: 0.17.1 - text-unidecode: 1.3 - textblob: 0.17.1 - tf-slim: 1.1.0 - thinc: 8.1.12 - threadpoolctl: 3.2.0 - tifffile: 2023.9.26 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.1.0+cu118 - torchaudio: 2.1.0+cu118 - torchdata: 0.7.0 - torchmetrics: 1.2.0 - torchsummary: 1.5.1 - torchtext: 0.16.0 - torchvision: 0.16.0+cu118 - tornado: 6.3.2 - tqdm: 4.66.1 - traitlets: 5.7.1 - traittypes: 0.2.1 - triton: 2.1.0 - tweepy: 4.14.0 - typer: 0.9.0 - types-pytz: 2023.3.1.1 - types-setuptools: 68.2.0.0 - typing-extensions: 4.5.0 - tzlocal: 5.2 - uc-micro-py: 1.0.2 - uritemplate: 4.1.1 - urllib3: 2.0.7 - vega-datasets: 0.9.0 - wadllib: 1.3.6 - wasabi: 1.1.2 - wcwidth: 0.2.9 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.4 - werkzeug: 3.0.1 - wheel: 0.41.3 - widgetsnbextension: 3.6.6 - wordcloud: 1.9.2 - wrapt: 1.14.1 - xarray: 2023.7.0 - xarray-einstats: 0.6.0 - xgboost: 2.0.1 - xlrd: 2.0.1 - xxhash: 3.4.1 - xyzservices: 2023.10.1 - yarl: 1.9.2 - yellowbrick: 1.5 - yfinance: 0.2.31 - zict: 3.0.0 - zipp: 3.17.0 * System: - OS: Linux - architecture: - 
64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.15.120+ - version: #1 SMP Wed Aug 30 11:19:59 UTC 2023
More info
No response
cc @borda @carmocca @justusschock @awaelchli