flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

[Bug]: error when training using mps #3506

Open joprice opened 2 months ago

joprice commented 2 months ago

Describe the bug

When setting flair.device to mps, the following error is thrown during training:

RuntimeError: User specified an unsupported autocast device_type 'mps'

To Reproduce

import flair
import torch
flair.device = torch.device("mps")
# ... build and train model

Expected behavior

Torch's mps support should be usable via flair.

Logs and Stack traces

No response

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair: 0.13.1
Pytorch: 2.3.1
Transformers: 4.42.4
GPU: False

BoilerToad commented 2 months ago

I got past that by upgrading to newer versions of torch (2.5.0) and transformers (4.43.3). With Flair 0.13.1 or 0.14.0, training then fails with a different error:

Traceback (most recent call last):
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 165, in <module>
    main()
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 138, in main
    trainer.train(
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 200, in train
    return self.train_custom(**local_variables, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 600, in train_custom
    with torch.autocast(device_type=flair.device.type, enabled=use_amp):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 230, in __init__
    dtype = torch.get_autocast_dtype(device_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unsupported scalarType

I verified that my Torch installation works and accepts "mps" using the sample code from https://github.com/mrdbourke/pytorch-apple-silicon:

Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps

... and the model works as expected there.
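
Both errors reported in this thread appear to be raised while constructing torch.autocast with device_type "mps" (trainer.py, line 600 in the traceback), before any training step runs. One possible workaround is to fall back to a no-op context manager whenever the device type is not one that autocast is known to accept. The sketch below is only an illustration: the helper name autocast_or_noop and the AUTOCAST_DEVICE_TYPES set are hypothetical and not part of Flair or PyTorch, and which device types autocast supports depends on the installed torch version.

```python
import contextlib

# Assumption: device types on which autocast is known to work in the installed
# torch build. "mps" is deliberately left out because of the errors reported above.
AUTOCAST_DEVICE_TYPES = {"cuda", "cpu"}


def autocast_or_noop(device_type: str, use_amp: bool):
    """Return torch.autocast(...) where supported, otherwise a no-op context.

    Hypothetical helper, not part of Flair or PyTorch.
    """
    if not use_amp or device_type not in AUTOCAST_DEVICE_TYPES:
        # Skip autocast entirely so its constructor never sees "mps".
        return contextlib.nullcontext()
    import torch  # imported lazily; the no-op path does not need torch

    return torch.autocast(device_type=device_type, enabled=True)


# On MPS the helper yields a plain no-op context, so the body runs unchanged.
with autocast_or_noop("mps", use_amp=True):
    pass
```

In a training loop this would stand in for the direct call, for example `with autocast_or_noop(flair.device.type, use_amp):`. The trade-off is that training on MPS then silently runs without mixed precision.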