pfeatherstone closed this issue 8 months ago.
@pfeatherstone This happens because the input is float32 while the model weights are bfloat16. To fix it, create the input with the matching precision:
```python
import torch

x = torch.randn(10, 32, dtype=torch.bfloat16)
```
I don't think this can be handled automatically by Lightning.
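The mismatch can be reproduced outside Lightning. The sketch below uses a plain `nn.Linear(32, 4)` as a hypothetical stand-in for the model in the issue, since the original reproduction is not shown here. A float32 input is rejected by bfloat16 weights, while an input created with `dtype=torch.bfloat16` runs.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the issue's model: a linear layer whose weights are bfloat16.
model = nn.Linear(32, 4).to(torch.bfloat16)

# torch.randn defaults to float32, so this input does not match the weights.
x_fp32 = torch.randn(10, 32)
try:
    model(x_fp32)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err  # PyTorch refuses to mix float32 inputs with bfloat16 weights

# Fix from the reply: create the input with the same precision as the weights.
x = torch.randn(10, 32, dtype=torch.bfloat16)
y = model(x)
print(mismatch_error is not None, y.dtype)
```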
Bug description
Exporting a model with bfloat16 weights fails. Lightning normally moves the inputs to the correct device, and I assumed it also cast them to the correct dtype. That does not appear to be the case.
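The expectation here is that inputs are matched to the model's dtype as well as its device. One option is to do that cast manually before exporting, for example before passing the sample as `input_sample` to `LightningModule.to_onnx`, if that is the export path in use. This is a minimal sketch under stated assumptions: `cast_like_model` and `model_device_and_dtype` are hypothetical helpers, not Lightning APIs, and they assume the model uses a single floating-point precision.

```python
import torch
import torch.nn as nn


def model_device_and_dtype(model: nn.Module):
    """Device and dtype of the first floating-point parameter.

    Assumes the whole model uses a single floating-point precision.
    Falls back to CPU and the default dtype if there are no such parameters.
    """
    for p in model.parameters():
        if p.is_floating_point():
            return p.device, p.dtype
    return torch.device("cpu"), torch.get_default_dtype()


def cast_like_model(model: nn.Module, sample):
    """Hypothetical helper (not a Lightning API): move `sample` to the model's device
    and cast only its floating-point tensors to the model's dtype.

    Integer tensors such as token ids keep their dtype; tuples, lists and dicts
    are processed recursively; anything else is returned unchanged.
    """
    device, dtype = model_device_and_dtype(model)
    if isinstance(sample, torch.Tensor):
        if sample.is_floating_point():
            return sample.to(device=device, dtype=dtype)
        return sample.to(device=device)
    if isinstance(sample, (tuple, list)):
        return type(sample)(cast_like_model(model, s) for s in sample)
    if isinstance(sample, dict):
        return {k: cast_like_model(model, v) for k, v in sample.items()}
    return sample


# Usage with a hypothetical bfloat16 model.
model = nn.Linear(32, 4).to(torch.bfloat16)
features = torch.randn(10, 32)              # float32 by default
token_ids = torch.randint(0, 100, (10,))    # int64, must stay integer
cast = cast_like_model(model, {"features": features, "ids": token_ids})
out = model(cast["features"])
```

Only floating-point tensors are cast because converting integer inputs such as token ids to bfloat16 would break embedding lookups. That is one plausible reason a blanket dtype cast of every input is hard to apply automatically.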
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU:
    - NVIDIA RTX A6000
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.1.3
  - lightning-utilities: 0.10.0
  - pytorch-lightning: 2.1.2
  - recurrent-memory-transformer-pytorch: 0.5.5
  - torch: 2.1.2
  - torchaudio: 2.1.2
  - torchmetrics: 1.2.0
  - torchvision: 0.16.2
* Packages:
  - absl-py: 2.0.0
  - aiofiles: 23.2.1
  - aiohttp: 3.9.1
  - aiosignal: 1.3.1
  - argparse: 1.4.0
  - async-timeout: 4.0.3
  - attrs: 23.1.0
  - cachetools: 5.3.2
  - certifi: 2023.11.17
  - charset-normalizer: 3.3.2
  - coloredlogs: 15.0.1
  - contourpy: 1.2.0
  - cycler: 0.12.1
  - datasets: 2.16.1
  - dill: 0.3.7
  - einops: 0.7.0
  - filelock: 3.13.1
  - flashlight: 0.1.1
  - flashlight-text: 0.0.4
  - flatbuffers: 23.5.26
  - fonttools: 4.45.1
  - frozenlist: 1.4.0
  - fsspec: 2023.10.0
  - future: 0.18.3
  - google-auth: 2.23.4
  - google-auth-oauthlib: 1.1.0
  - grpcio: 1.59.3
  - httptools: 0.6.1
  - huggingface-hub: 0.20.1
  - humanfriendly: 10.0
  - idna: 3.6
  - jaxtyping: 0.2.23
  - jinja2: 3.1.2
  - kiwisolver: 1.4.5
  - lightning: 2.1.3
  - lightning-utilities: 0.10.0
  - llvmlite: 0.41.1
  - mako: 1.3.0
  - markdown: 3.5.1
  - markdown-it-py: 3.0.0
  - markupsafe: 2.1.3
  - matplotlib: 3.8.2
  - mdurl: 0.1.2
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - multiprocess: 0.70.15
  - networkx: 3.2.1
  - numba: 0.58.1
  - numpy: 1.26.2
  - nvidia-cublas-cu12: 12.1.3.1
  - nvidia-cuda-cupti-cu12: 12.1.105
  - nvidia-cuda-nvrtc-cu12: 12.1.105
  - nvidia-cuda-runtime-cu12: 12.1.105
  - nvidia-cudnn-cu12: 8.9.2.26
  - nvidia-cufft-cu12: 11.0.2.54
  - nvidia-curand-cu12: 10.3.2.106
  - nvidia-cusolver-cu12: 11.4.5.107
  - nvidia-cusparse-cu12: 12.1.0.106
  - nvidia-nccl-cu12: 2.18.1
  - nvidia-nvjitlink-cu12: 12.3.101
  - nvidia-nvtx-cu12: 12.1.105
  - oauthlib: 3.2.2
  - onnx: 1.15.0
  - onnxruntime: 1.16.3
  - onnxruntime-gpu: 1.16.3
  - onnxscript: 0.1.0.dev20231213
  - onnxsim: 0.4.35
  - packaging: 23.2
  - pandas: 2.1.4
  - pillow: 10.1.0
  - pip: 22.0.2
  - pipe: 2.0
  - protobuf: 4.23.4
  - pyarrow: 14.0.2
  - pyarrow-hotfix: 0.6
  - pyasn1: 0.5.1
  - pyasn1-modules: 0.3.0
  - pybombs: 2.3.5
  - pygments: 2.17.2
  - pyparsing: 3.1.1
  - pyqt5: 5.15.10
  - pyqt5-qt5: 5.15.2
  - pyqt5-sip: 12.13.0
  - pyqtgraph: 0.13.3
  - python-dateutil: 2.8.2
  - pytorch-lightning: 2.1.2
  - pytz: 2023.3.post1
  - pyyaml: 6.0.1
  - recurrent-memory-transformer-pytorch: 0.5.5
  - requests: 2.31.0
  - requests-oauthlib: 1.3.1
  - rich: 13.7.0
  - rsa: 4.9
  - ruamel.yaml: 0.18.5
  - ruamel.yaml.clib: 0.2.8
  - sanic: 0.7.0
  - scipy: 1.11.4
  - setuptools: 59.6.0
  - six: 1.16.0
  - sympy: 1.12
  - tensorboard: 2.15.1
  - tensorboard-data-server: 0.7.2
  - torch: 2.1.2
  - torchaudio: 2.1.2
  - torchmetrics: 1.2.0
  - torchvision: 0.16.2
  - tqdm: 4.66.1
  - triton: 2.1.0
  - typeguard: 2.13.3
  - typing-extensions: 4.8.0
  - tzdata: 2023.4
  - uhd: 4.6.0
  - ujson: 5.9.0
  - urllib3: 2.1.0
  - uvloop: 0.19.0
  - websockets: 12.0
  - werkzeug: 3.0.1
  - x-transformers: 1.27.9
  - xxhash: 3.4.1
  - yarl: 1.9.3
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 6.2.0-33-generic
  - version: #33~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 7 10:33:52 UTC 2

More info
No response