joprice opened this issue 2 months ago (status: open)
I got past that by upgrading to torch 2.5.0 and transformers 4.43.3. With Flair 0.13.1 or 0.14.0, however, I get the following error when training a model ...
Traceback (most recent call last):
File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 165, in
I verified that my Torch installation works and accepts "mps" using the sample code from https://github.com/mrdbourke/pytorch-apple-silicon:
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps
...and the model works as expected there.
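The check above can be sketched as follows. This is a minimal reconstruction, not the linked repository's exact script: it calls torch.backends.mps.is_built() and torch.backends.mps.is_available(), falls back to the CPU when MPS (or torch itself) is missing, and allocates a small tensor on the chosen device as a smoke test. The helper name mps_report is hypothetical.

```python
import importlib.util


def mps_report() -> dict:
    """Run the MPS checks and return what was found.

    Falls back to the CPU when torch is missing or MPS is unavailable.
    """
    if importlib.util.find_spec("torch") is None:
        # torch is not installed, so MPS cannot be used at all.
        return {"built": False, "available": False, "device": "cpu"}

    import torch

    built = torch.backends.mps.is_built()          # compiled with MPS support?
    available = torch.backends.mps.is_available()  # usable on this machine/OS?
    device = "mps" if available else "cpu"

    # Smoke test: allocate a small tensor on the chosen device and report
    # the device it actually landed on.
    x = torch.rand(3, 4, device=device)
    return {"built": built, "available": available, "device": x.device.type}


if __name__ == "__main__":
    report = mps_report()
    print(f"Is MPS (Metal Performance Shader) built? {report['built']}")
    print(f"Is MPS available? {report['available']}")
    print(f"Using device: {report['device']}")
```

If this reports "mps" but Flair training still fails, the problem is more likely in how Flair uses the device than in the PyTorch installation itself.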
Describe the bug
When setting flair.device to mps, the following error is thrown during training:
To Reproduce
Expected behavior
Torch's MPS support should be usable via Flair.
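For reference, a minimal sketch of the setting in question, assuming the usual Flair pattern of assigning a torch.device to the module-level flair.device before embeddings and models are created. It is not the reproduction script from this report (which is not included here); the helper name set_flair_device is hypothetical, and it falls back to the CPU when MPS is unavailable.

```python
import importlib.util
from typing import Optional


def set_flair_device() -> Optional[str]:
    """Point flair.device at MPS when PyTorch reports it available, else at the CPU.

    Returns the device type that was set, or None if Flair/PyTorch are not installed.
    """
    if importlib.util.find_spec("flair") is None or importlib.util.find_spec("torch") is None:
        return None

    import flair
    import torch

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Assumption: set this early, before creating embeddings or models,
    # so that objects built afterwards pick up the chosen device.
    flair.device = device
    return device.type


if __name__ == "__main__":
    print(f"flair.device set to: {set_flair_device()}")
```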
Logs and Stack traces
No response
Screenshots
No response
Additional Context
No response
Environment
Versions:
Flair: 0.13.1
PyTorch: 2.3.1
Transformers: 4.42.4
GPU: False
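The version fields above can be collected with a short script like the following, a sketch that reads installed package versions via importlib.metadata. Assumptions: the distribution names are flair, torch and transformers; the template's GPU field means CUDA availability (torch.cuda.is_available()); and the MPS entry is an extra not in the template. The helper name environment_versions is hypothetical.

```python
import importlib.metadata
import importlib.util


def environment_versions() -> dict:
    """Collect the versions asked for in the issue template.

    Packages that are not installed are reported as "not installed".
    """
    versions = {}
    for package in ("flair", "torch", "transformers"):
        try:
            versions[package] = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            versions[package] = "not installed"

    gpu = False
    mps = False
    if importlib.util.find_spec("torch") is not None:
        import torch

        gpu = torch.cuda.is_available()          # CUDA GPU usable?
        mps = torch.backends.mps.is_available()  # Apple MPS backend usable?
    versions["GPU (CUDA)"] = gpu
    versions["MPS"] = mps
    return versions


if __name__ == "__main__":
    for name, value in environment_versions().items():
        print(f"{name}: {value}")
```

On an Apple Silicon machine like the one described here, this would be expected to show GPU (CUDA) as False and MPS as True, which matches the "GPU False" entry above.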