Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Fix for error trying to use MPS on Intel MacBook. #17667

Open mwwconsulting opened 1 year ago

mwwconsulting commented 1 year ago

Bug description

As of, I believe, macOS 12.3, MPS is supported on Intel Macs, not just ARM. According to this, MPS supports M1s and AMD GPUs. This means that the assumption made in lightning/src/../accelerators/mps.py, namely that the processor must be ARM and that otherwise we're in emulation mode, is incorrect. When I commented out the last clause of the check, my Intel Mac ran the code that I had been running on my M1 Mac.

def is_available() -> bool:
    """MPS is only available for certain torch builds starting at torch>=1.12, and is only enabled on a machine
    with the ARM-based Apple Silicon processors."""
    return (
        _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
    )

Changed to

def is_available() -> bool:
    """MPS is only available for certain torch builds starting at torch>=1.12, and is only enabled on a machine
    with the ARM-based Apple Silicon processors."""
    return (
        _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available()
    )

Is it safe to simply comment out the last part and assume that if torch.backends.mps.is_available() returns True, then we're in the clear? I'm happy to run a test suite to check whether my change is safe; I'm just not sure which one to run.
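To make the difference between the two checks concrete, here is a minimal standalone sketch, not Lightning's actual code. The function names `mps_backend_available` and `mps_available_arm_only` are hypothetical, and the `try`/`except` around the import is only there so the snippet also runs where PyTorch is not installed.

```python
import platform

try:
    import torch
except ImportError:  # lets the sketch run where PyTorch is absent
    torch = None


def mps_backend_available() -> bool:
    """Relaxed check: trust PyTorch's own report about the MPS backend."""
    if torch is None:
        return False
    mps = getattr(torch.backends, "mps", None)  # attribute is missing before torch 1.12
    return mps is not None and bool(mps.is_available())


def mps_available_arm_only() -> bool:
    """Original-style check: additionally require an ARM processor string."""
    return mps_backend_available() and platform.processor() in ("arm", "arm64")


if __name__ == "__main__":
    print("processor:", platform.processor())
    print("relaxed check:", mps_backend_available())
    print("ARM-only check:", mps_available_arm_only())
```

On the machine described in this report (processor: i386), the ARM-only variant can never return True, whatever PyTorch reports about the backend.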

What version are you seeing the problem on?

2.0.2

How to reproduce the bug

import lightning.pytorch as pl  # assumed import; the report only shows the Trainer call
trainer = pl.Trainer(accelerator="mps", devices="1", precision="16-mixed", max_epochs=75)


Environment

Current environment

* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - lightning: 2.0.2
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.2
    - torch: 2.0.1
    - torchmetrics: 0.11.4
* Packages:
    - absl-py: 1.4.0
    - aiohttp: 3.8.4
    - aiohttp-cors: 0.7.0
    - aiorwlock: 1.3.0
    - aiosignal: 1.3.1
    - alabaster: 0.7.12
    - alembic: 1.10.4
    - anyio: 3.6.2
    - appdirs: 1.4.4
    - applaunchservices: 0.3.0
    - appnope: 0.1.2
    - appscript: 1.1.2
    - apsw: 3.41.2.0
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - astroid: 2.14.2
    - astropy: 5.1
    - asttokens: 2.0.5
    - astunparse: 1.6.3
    - async-timeout: 4.0.2
    - atomicwrites: 1.4.0
    - attrs: 23.1.0
    - automat: 20.2.0
    - autopep8: 1.6.0
    - babel: 2.11.0
    - backcall: 0.2.0
    - backoff: 2.2.1
    - bcrypt: 3.2.0
    - beautifulsoup4: 4.12.2
    - binaryornot: 0.4.4
    - black: 0.0
    - bleach: 4.1.0
    - blessed: 1.20.0
    - bokeh: 2.4.3
    - boto3: 1.26.133
    - botocore: 1.29.133
    - bottleneck: 1.3.5
    - bravado: 11.0.3
    - bravado-core: 5.17.1
    - brotlipy: 0.7.0
    - cachetools: 5.3.0
    - certifi: 2023.5.7
    - cffi: 1.15.1
    - chardet: 4.0.0
    - charset-normalizer: 3.1.0
    - chess: 1.9.4
    - click: 8.1.3
    - cloudpickle: 2.2.1
    - cmaes: 0.9.1
    - colorama: 0.4.6
    - colorcet: 3.0.1
    - colorful: 0.5.5
    - colorlog: 6.7.0
    - comet-ml: 3.33.2
    - comm: 0.1.2
    - configobj: 5.0.8
    - connectorx: 0.3.1
    - constantly: 15.1.0
    - contourpy: 1.0.5
    - cookiecutter: 1.7.3
    - croniter: 1.3.14
    - cryptography: 39.0.1
    - cssselect: 1.1.0
    - cycler: 0.11.0
    - cython: 0.29.34
    - cytoolz: 0.12.0
    - daal4py: 2023.1.1
    - dask: 2023.4.1
    - datashader: 0.14.4
    - datashape: 0.5.4
    - dateutils: 0.6.12
    - debugpy: 1.5.1
    - decorator: 5.1.1
    - deepdiff: 6.3.0
    - defusedxml: 0.7.1
    - deprecated: 1.2.13
    - diff-match-patch: 20200713
    - dill: 0.3.6
    - distlib: 0.3.6
    - distributed: 2023.4.1
    - dm-tree: 0.1.8
    - docker-pycreds: 0.4.0
    - docstring-to-markdown: 0.11
    - docutils: 0.18.1
    - duckdb: 0.8.0
    - dulwich: 0.21.5
    - endgame-nn: 0.1.0
    - entrypoints: 0.4
    - et-xmlfile: 1.1.0
    - everett: 3.1.0
    - executing: 0.8.3
    - fastapi: 0.88.0
    - fastjsonschema: 2.16.2
    - filelock: 3.12.0
    - flake8: 6.0.0
    - flask: 2.2.2
    - flatbuffers: 23.5.9
    - fonttools: 4.25.0
    - fqdn: 1.5.1
    - frozenlist: 1.3.3
    - fsspec: 2023.5.0
    - future: 0.18.3
    - gast: 0.4.0
    - gensim: 4.3.0
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - gmpy2: 2.1.2
    - google-api-core: 2.11.0
    - google-auth: 2.18.0
    - google-auth-oauthlib: 1.0.0
    - google-pasta: 0.2.0
    - googleapis-common-protos: 1.59.0
    - gpustat: 1.1
    - greenlet: 2.0.1
    - grpcio: 1.49.1
    - gymnasium: 0.26.3
    - gymnasium-notices: 0.0.1
    - h11: 0.14.0
    - h5py: 3.7.0
    - heapdict: 1.0.1
    - holoviews: 1.15.4
    - huggingface-hub: 0.10.1
    - hvplot: 0.8.2
    - hyperlink: 21.0.0
    - idna: 3.4
    - imagecodecs: 2021.8.26
    - imageio: 2.26.0
    - imagesize: 1.4.1
    - imbalanced-learn: 0.10.1
    - importlib-metadata: 6.0.0
    - incremental: 21.3.0
    - inflection: 0.5.1
    - iniconfig: 1.1.1
    - inquirer: 3.1.3
    - intake: 0.6.8
    - intel-extension-for-tensorflow: 0.0.0.dev1
    - intervaltree: 3.1.0
    - ipykernel: 6.19.2
    - ipython: 8.12.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 8.0.4
    - isoduration: 20.11.0
    - isort: 5.9.3
    - itemadapter: 0.3.0
    - itemloaders: 1.0.4
    - itsdangerous: 2.1.2
    - jaraco.classes: 3.2.1
    - jax: 0.4.10
    - jedi: 0.18.1
    - jellyfish: 0.9.0
    - jinja2: 3.1.2
    - jinja2-time: 0.2.0
    - jmespath: 0.10.0
    - joblib: 1.1.1
    - json5: 0.9.6
    - jsonpointer: 2.3
    - jsonref: 1.1.0
    - jsonschema: 4.17.3
    - jupyter: 1.0.0
    - jupyter-client: 8.1.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.3.0
    - jupyter-server: 1.23.4
    - jupyterlab: 3.5.3
    - jupyterlab-pygments: 0.1.2
    - jupyterlab-server: 2.22.0
    - jupyterlab-widgets: 3.0.5
    - keras: 2.12.0
    - keyring: 23.13.1
    - kiwisolver: 1.4.4
    - kubernetes: 26.1.0
    - lazy-loader: 0.1
    - lazy-object-proxy: 1.6.0
    - libclang: 16.0.0
    - lightning: 2.0.2
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - llvmlite: 0.40.0
    - lmdb: 1.4.1
    - locket: 1.0.0
    - lxml: 4.9.2
    - lz4: 4.3.2
    - mako: 1.2.4
    - markdown: 3.4.1
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - maturin: 0.15.1
    - mccabe: 0.7.0
    - mdurl: 0.1.2
    - mistune: 0.8.4
    - ml-dtypes: 0.1.0
    - mock: 4.0.3
    - modin: 0.20.1
    - monotonic: 1.6
    - more-itertools: 8.12.0
    - mpmath: 1.3.0
    - msgpack: 1.0.3
    - multidict: 6.0.4
    - multipledispatch: 0.6.0
    - munkres: 1.1.4
    - mypy-extensions: 0.4.3
    - nbclassic: 0.5.5
    - nbclient: 0.5.13
    - nbconvert: 6.5.4
    - nbformat: 5.7.0
    - neptune: 1.2.0
    - nest-asyncio: 1.5.6
    - networkx: 3.1
    - nltk: 3.7
    - notebook: 6.5.4
    - notebook-shim: 0.2.2
    - numba: 0.57.0
    - numexpr: 2.8.4
    - numpy: 1.24.3
    - numpydoc: 1.5.0
    - nvidia-ml-py: 11.525.112
    - oauthlib: 3.2.2
    - opencensus: 0.11.2
    - opencensus-context: 0.1.3
    - openpyxl: 3.0.10
    - opentelemetry-api: 1.17.0
    - opentelemetry-exporter-otlp: 1.17.0
    - opentelemetry-exporter-otlp-proto-grpc: 1.17.0
    - opentelemetry-exporter-otlp-proto-http: 1.17.0
    - opentelemetry-proto: 1.17.0
    - opentelemetry-sdk: 1.17.0
    - opentelemetry-semantic-conventions: 0.38b0
    - opt-einsum: 3.3.0
    - optuna: 3.1.1
    - ordered-set: 4.1.0
    - packaging: 23.1
    - pandas: 1.5.3
    - pandocfilters: 1.5.0
    - panel: 0.14.3
    - param: 1.12.3
    - parsel: 1.6.0
    - parso: 0.8.3
    - partd: 1.2.0
    - pathspec: 0.10.3
    - pathtools: 0.1.2
    - patsy: 0.5.3
    - peewee: 3.16.2
    - pep8: 1.7.1
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.0.1
    - platformdirs: 2.5.2
    - plotly: 5.9.0
    - pluggy: 1.0.0
    - ply: 3.11
    - polars: 0.17.13
    - pooch: 1.4.0
    - poyo: 0.5.0
    - prometheus-client: 0.14.1
    - prompt-toolkit: 3.0.36
    - protego: 0.1.16
    - protobuf: 4.23.0
    - psutil: 5.9.5
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py: 1.11.0
    - py-spy: 0.3.14
    - pyarrow: 11.0.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycodestyle: 2.10.0
    - pycparser: 2.21
    - pyct: 0.5.0
    - pycurl: 7.45.2
    - pydantic: 1.10.7
    - pydispatcher: 2.0.5
    - pydocstyle: 6.3.0
    - pyerfa: 2.0.0
    - pyflakes: 3.0.1
    - pygments: 2.15.1
    - pyjwt: 2.7.0
    - pylint: 2.16.2
    - pylint-venv: 2.3.0
    - pyls-spyder: 0.4.0
    - pyobjc-core: 9.0
    - pyobjc-framework-cocoa: 9.0
    - pyobjc-framework-coreservices: 9.0
    - pyobjc-framework-fsevents: 9.0
    - pyodbc: 4.0.34
    - pyopenssl: 23.0.0
    - pyparsing: 3.0.9
    - pyqt5-sip: 12.11.0
    - pyrsistent: 0.18.0
    - pysocks: 1.7.1
    - pytest: 7.1.2
    - python-box: 6.1.0
    - python-dateutil: 2.8.2
    - python-editor: 1.0.4
    - python-lsp-black: 1.2.1
    - python-lsp-jsonrpc: 1.0.0
    - python-lsp-server: 1.7.2
    - python-multipart: 0.0.6
    - python-slugify: 5.0.2
    - python-snappy: 0.6.1
    - pytoolconfig: 1.2.5
    - pytorch-lightning: 2.0.2
    - pytz: 2023.3
    - pyviz-comms: 2.0.2
    - pywavelets: 1.4.1
    - pyyaml: 6.0
    - pyzmq: 25.0.2
    - qdarkstyle: 3.0.2
    - qstylizer: 0.2.2
    - qtawesome: 1.2.2
    - qtconsole: 5.4.2
    - qtpy: 2.2.0
    - queuelib: 1.5.0
    - ray: 2.4.0
    - ray-cpp: 2.4.0
    - readchar: 4.0.5
    - regex: 2022.7.9
    - requests: 2.30.0
    - requests-file: 1.5.1
    - requests-oauthlib: 1.3.1
    - requests-toolbelt: 1.0.0
    - rfc3339-validator: 0.1.4
    - rfc3987: 1.3.8
    - rich: 13.3.5
    - rope: 1.7.0
    - rsa: 4.9
    - rtree: 1.0.1
    - s3transfer: 0.6.1
    - scikit-image: 0.20.0
    - scikit-learn: 1.2.2
    - scikit-learn-intelex: 20230426.61733
    - scipy: 1.10.1
    - scrapy: 2.8.0
    - seaborn: 0.12.2
    - semantic-version: 2.10.0
    - send2trash: 1.8.0
    - sentry-sdk: 1.22.2
    - service-identity: 18.1.0
    - setproctitle: 1.3.2
    - setuptools: 67.7.2
    - simplejson: 3.19.1
    - sip: 6.6.2
    - six: 1.16.0
    - smart-open: 5.2.1
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - snowballstemmer: 2.2.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.4.1
    - sphinx: 5.0.2
    - sphinxcontrib-applehelp: 1.0.2
    - sphinxcontrib-devhelp: 1.0.2
    - sphinxcontrib-htmlhelp: 2.0.0
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.3
    - sphinxcontrib-serializinghtml: 1.1.5
    - spyder: 5.4.3
    - spyder-kernels: 2.4.3
    - sqlalchemy: 1.4.39
    - stack-data: 0.2.0
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - statsmodels: 0.13.5
    - swagger-spec-validator: 3.0.3
    - sympy: 1.12
    - tables: 3.7.0
    - tabulate: 0.8.10
    - tbb: 0.2
    - tblib: 1.7.0
    - tenacity: 8.2.2
    - tensorboard: 2.12.3
    - tensorboard-data-server: 0.7.0
    - tensorboardx: 2.6
    - tensorflow: 2.12.0
    - tensorflow-estimator: 2.12.0
    - tensorflow-io-gcs-filesystem: 0.32.0
    - termcolor: 2.3.0
    - terminado: 0.17.1
    - text-unidecode: 1.3
    - textdistance: 4.2.1
    - threadpoolctl: 2.2.0
    - three-merge: 0.1.1
    - tifffile: 2021.7.2
    - tinycss2: 1.2.1
    - tldextract: 3.2.0
    - tokenizers: 0.11.4
    - toml: 0.10.2
    - tomli: 2.0.1
    - tomlkit: 0.11.1
    - toolz: 0.12.0
    - torch: 2.0.1
    - torchmetrics: 0.11.4
    - tornado: 6.2
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - transformers: 4.24.0
    - twisted: 22.10.0
    - typer: 0.9.0
    - typing-extensions: 4.5.0
    - ujson: 5.4.0
    - unidecode: 1.2.0
    - uri-template: 1.2.0
    - urllib3: 1.26.15
    - uvicorn: 0.22.0
    - virtualenv: 20.21.0
    - w3lib: 1.21.0
    - wandb: 0.15.2
    - watchdog: 2.1.6
    - wcwidth: 0.2.6
    - webcolors: 1.13
    - webencodings: 0.5.1
    - websocket-client: 1.5.1
    - websockets: 11.0.3
    - werkzeug: 2.2.3
    - whatthepatch: 1.0.2
    - wheel: 0.38.4
    - widgetsnbextension: 4.0.5
    - wrapt: 1.14.1
    - wurlitzer: 3.0.2
    - xarray: 2022.11.0
    - xlwings: 0.29.1
    - yapf: 0.31.0
    - yarl: 1.9.2
    - zict: 2.2.0
    - zipp: 3.11.0
    - zope.interface: 5.4.0
* System:
    - OS: Darwin
    - architecture: 64bit
    - processor: i386
    - python: 3.10.11
    - release: 22.4.0
    - version: Darwin Kernel Version 22.4.0: Mon Mar 6 21:00:17 PST 2023; root:xnu-8796.101.5~3/RELEASE_X86_64

More info

No response

basilbuw commented 1 year ago

I'm also wondering about this. I just installed PyTorch on my 2019 MacBook Pro, which still uses an Intel processor. But if I run print(torch.backends.mps.is_available()), it prints True. So is MPS actually supported on my MacBook, or is it just emulation?
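One way to gather more evidence for this question is to print what Python and PyTorch each report about the machine. The sketch below is only a diagnostic aid under stated assumptions: `describe_mps_support` is a hypothetical helper name, and the values it returns show what the installed PyTorch build claims, not whether work actually lands on the GPU.

```python
import platform

try:
    import torch
except ImportError:  # keep the sketch runnable without PyTorch
    torch = None


def describe_mps_support() -> dict:
    """Collect the platform strings and PyTorch's MPS flags in one place."""
    info = {
        "machine": platform.machine(),      # hardware name as reported by the OS
        "processor": platform.processor(),  # the string the Lightning check inspects
        "torch": getattr(torch, "__version__", None),
        "mps_built": False,
        "mps_available": False,
    }
    mps = getattr(getattr(torch, "backends", None), "mps", None)
    if mps is not None:
        info["mps_built"] = bool(mps.is_built())          # build was compiled with MPS
        info["mps_available"] = bool(mps.is_available())  # backend usable right now
    return info


if __name__ == "__main__":
    for key, value in describe_mps_support().items():
        print(f"{key}: {value}")
```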

mwwconsulting commented 1 year ago

My training runs ~10x faster than when I had CPU mode on.

mwwconsulting commented 1 year ago

I guess I improved something else along the way, it's 3 or 4x faster, but still...

accelerator="cpu" devices="auto"
Epoch 0/174 ━━━━━━━━━━━━━━━━━━━━ 65/7813 0:00:42 • 1:19:40 1.62it/s v_num: y_17 train_loss: 0.054

accelerator="mps" devices="1"
Epoch 96/174 ━━━━━━━━━━━━━━━━━━━ 7813/7813 0:19:54 • 0:00:00 5.56it/s v_num: 1_16 train_loss: 0.028 val_loss: 0.029
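As a rough check of the "3 or 4x" figure, the two progress-bar rates above can be divided directly. This is only back-of-the-envelope arithmetic: the it/s values are single readings taken from the log lines, not averages over a full run, and the variable names are illustrative.

```python
cpu_its = 1.62          # it/s from the accelerator="cpu" log line
mps_its = 5.56          # it/s from the accelerator="mps" log line
steps_per_epoch = 7813  # batches per epoch, from the same log lines

speedup = mps_its / cpu_its
cpu_minutes = steps_per_epoch / cpu_its / 60  # epoch time if the CPU rate held
mps_minutes = steps_per_epoch / mps_its / 60  # epoch time if the MPS rate held

print(f"speedup: {speedup:.2f}x")
print(f"epoch at CPU rate: {cpu_minutes:.1f} min")
print(f"epoch at MPS rate: {mps_minutes:.1f} min")
```

The ratio lands a little under 3.5x, consistent with the estimate in the comment.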

mwwconsulting commented 1 year ago

Out of curiosity, I started a second instance of my training run with MPS, and it has similar performance, so there may be more room for improvement: both scripts are running at about 5 it/s.

IngLP commented 1 year ago

I have the same problem! MPS on an Intel MacBook Pro works:

import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")

# prints "tensor([1.], device='mps:0')"

So this part of the condition should be removed: and platform.processor() in ("arm", "arm64")
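A slightly stronger smoke test than allocating one tensor is to run an actual computation on the MPS device and compare it with the CPU result. The sketch below is an assumption-laden example rather than an official test: `mps_matches_cpu` is a hypothetical helper, and the matrix size and tolerance are arbitrary choices.

```python
try:
    import torch
except ImportError:  # keep the sketch runnable without PyTorch
    torch = None


def mps_matches_cpu(size: int = 64, tol: float = 1e-4) -> bool:
    """Return True only if an MPS matmul runs and agrees with the CPU result."""
    if torch is None:
        return False
    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return False
    torch.manual_seed(0)
    a = torch.rand(size, size)
    b = torch.rand(size, size)
    expected = a @ b                            # reference result on the CPU
    actual = (a.to("mps") @ b.to("mps")).cpu()  # same product computed on MPS
    return bool(torch.allclose(expected, actual, atol=tol, rtol=tol))


if __name__ == "__main__":
    print("MPS matmul matches CPU:", mps_matches_cpu())
```

A False result means either that MPS is unavailable or that the results disagree, so it is worth printing torch.backends.mps.is_available() alongside it.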

mickolka commented 1 year ago

Can confirm that MPS works once the condition is removed. Also, this seems like a duplicate of #15861.

ringohoffman commented 10 months ago

Related: https://github.com/microsoft/vscode-python/issues/22614?