Closed: erl61 closed this issue 1 year ago.
@erl61 Thanks for reporting. The MPS backend in PyTorch doesn't support an int64 tensor as the source argument of index_add(). This happens during the backward pass, and a layer in your model probably uses int64 as its data type somewhere. I suggest you report it to PyTorch directly; I don't see how Lightning could do anything about it. The MPS backend does not support all operations and data types (yet).
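The error message refers to the source argument of index_add(), which accumulates values from source into a tensor at the positions listed in index. The following is a minimal pure-Python sketch of the 1-D (dim=0) case, included only to illustrate those semantics. index_add_1d is a hypothetical helper, not a PyTorch API, and it does not reproduce the MPS dtype restriction itself.

```python
from typing import List, Sequence


def index_add_1d(
    self_vals: Sequence[float],
    index: Sequence[int],
    source: Sequence[float],
    alpha: float = 1.0,
) -> List[float]:
    """Hypothetical pure-Python illustration of 1-D index_add semantics.

    Returns a copy of self_vals in which alpha * source[i] has been added
    at position index[i]. Repeated indices accumulate, mirroring
    torch.Tensor.index_add along dim=0.
    """
    if len(index) != len(source):
        raise ValueError("index and source must have the same length")
    out = list(self_vals)
    for i, target in enumerate(index):
        # Reject out-of-range (including negative) positions explicitly
        # instead of relying on Python's negative-index wrap-around.
        if not 0 <= target < len(out):
            raise IndexError(f"index {target} out of range for size {len(out)}")
        out[target] += alpha * source[i]
    return out


# Positions 0 and 2 receive contributions; position 0 is hit twice.
print(index_add_1d([0.0, 0.0, 0.0], [0, 2, 0], [1.0, 2.0, 3.0]))
# → [4.0, 0.0, 2.0]
```

On a real tensor, the analogous call is torch.zeros(3).index_add(0, torch.tensor([0, 2, 0]), torch.tensor([1.0, 2.0, 3.0])). According to the reply above, the MPS implementation in this setup rejects that call when source is int64.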
Bug description
I am trying to run the "Demand forecasting with the Temporal Fusion Transformer" tutorial from PyTorch-Forecasting. It works perfectly with accelerator="cpu", but when I change it to accelerator="mps" it raises "RuntimeError: index_add(): Expected non int64 dtype for source." The error happens when I call lightning.pytorch.Trainer. I have a MacBook with an M1 Pro chip and have tried different versions of Python, PyTorch, PyTorch-Forecasting and PyTorch-Lightning. torch.backends.mps.is_available() returns True.
What version are you seeing the problem on?
v2.0, master
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning: 2.0.2
  - lightning-cloud: 0.5.36
  - lightning-utilities: 0.8.0
  - pytorch-forecasting: 1.0.0
  - pytorch-lightning: 2.0.2
  - pytorch-optimizer: 2.9.1
  - torch: 2.0.1
  - torchmetrics: 0.11.4
* Packages:
  - aiohttp: 3.8.4
  - aiosignal: 1.3.1
  - alembic: 1.11.1
  - altgraph: 0.17.2
  - anyio: 3.6.2
  - appnope: 0.1.3
  - arrow: 1.2.3
  - asttokens: 2.2.1
  - async-timeout: 4.0.2
  - attrs: 23.1.0
  - backcall: 0.2.0
  - beautifulsoup4: 4.12.2
  - blessed: 1.20.0
  - certifi: 2023.5.7
  - charset-normalizer: 3.1.0
  - click: 8.1.3
  - cmaes: 0.9.1
  - colorlog: 6.7.0
  - comm: 0.1.3
  - contourpy: 1.0.7
  - cramjam: 2.6.2
  - croniter: 1.3.14
  - cycler: 0.11.0
  - datasets: 2.12.0
  - dateutils: 0.6.12
  - debugpy: 1.6.7
  - decorator: 5.1.1
  - deepdiff: 6.3.0
  - dill: 0.3.6
  - executing: 1.2.0
  - fastapi: 0.88.0
  - fastparquet: 2023.4.0
  - filelock: 3.12.0
  - fonttools: 4.39.4
  - frozenlist: 1.3.3
  - fsspec: 2023.5.0
  - future: 0.18.2
  - h11: 0.14.0
  - huggingface-hub: 0.14.1
  - idna: 3.4
  - importlib-metadata: 6.6.0
  - importlib-resources: 5.12.0
  - inquirer: 3.1.3
  - ipykernel: 6.23.1
  - ipython: 8.13.2
  - itsdangerous: 2.1.2
  - jedi: 0.18.2
  - jinja2: 3.1.2
  - joblib: 1.2.0
  - jupyter-client: 8.2.0
  - jupyter-core: 5.3.0
  - kiwisolver: 1.4.4
  - lightning: 2.0.2
  - lightning-cloud: 0.5.36
  - lightning-utilities: 0.8.0
  - macholib: 1.15.2
  - mako: 1.2.4
  - markdown-it-py: 2.2.0
  - markupsafe: 2.1.2
  - matplotlib: 3.7.1
  - matplotlib-inline: 0.1.6
  - mdurl: 0.1.2
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - multiprocess: 0.70.14
  - nest-asyncio: 1.5.6
  - networkx: 3.1
  - numpy: 1.24.3
  - optuna: 3.1.1
  - ordered-set: 4.1.0
  - packaging: 23.1
  - pandas: 2.0.1
  - parso: 0.8.3
  - patsy: 0.5.3
  - pexpect: 4.8.0
  - pickleshare: 0.7.5
  - pillow: 9.5.0
  - pip: 23.1.2
  - platformdirs: 3.5.1
  - prompt-toolkit: 3.0.38
  - psutil: 5.9.5
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - pyarrow: 12.0.0
  - pydantic: 1.10.7
  - pygments: 2.15.1
  - pyjwt: 2.7.0
  - pyparsing: 3.0.9
  - python-dateutil: 2.8.2
  - python-editor: 1.0.4
  - python-multipart: 0.0.6
  - pytorch-forecasting: 1.0.0
  - pytorch-lightning: 2.0.2
  - pytorch-optimizer: 2.9.1
  - pytz: 2023.3
  - pyyaml: 6.0
  - pyzmq: 25.0.2
  - readchar: 4.0.5
  - regex: 2023.5.5
  - requests: 2.31.0
  - responses: 0.18.0
  - rich: 13.3.5
  - scikit-learn: 1.2.2
  - scipy: 1.10.1
  - setuptools: 58.0.4
  - six: 1.15.0
  - sniffio: 1.3.0
  - soupsieve: 2.4.1
  - sqlalchemy: 2.0.15
  - stack-data: 0.6.2
  - starlette: 0.22.0
  - starsessions: 1.3.0
  - statsmodels: 0.14.0
  - sympy: 1.12
  - threadpoolctl: 3.1.0
  - tokenizers: 0.13.3
  - torch: 2.0.1
  - torchmetrics: 0.11.4
  - tornado: 6.3.2
  - tqdm: 4.65.0
  - traitlets: 5.9.0
  - transformers: 4.29.2
  - typing-extensions: 4.5.0
  - tzdata: 2023.3
  - urllib3: 1.26.6
  - uvicorn: 0.22.0
  - wcwidth: 0.2.6
  - websocket-client: 1.5.2
  - websockets: 11.0.3
  - wheel: 0.37.0
  - xxhash: 3.2.0
  - yarl: 1.9.2
  - zipp: 3.15.0
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: arm
  - python: 3.9.6
  - release: 22.4.0
  - version: Darwin Kernel Version 22.4.0: Mon Mar 6 20:59:28 PST 2023; root:xnu-8796.101.5~3/RELEASE_ARM64_T6000

More info
No response
cc @justusschock