BaderLab / saber

Saber is a deep learning-based tool for information extraction in the biomedical domain. Pull requests are welcome! Note: this is a work in progress. Many things are broken, and the codebase is not stable.
https://baderlab.github.io/saber/
MIT License

AMP only works on CUDA devices #167

Closed: JohnGiorgi closed this issue 5 years ago

JohnGiorgi commented 5 years ago

If a user tries to train a mixed-precision-enabled model (bert-ner, bert-ner-re) on the CPU while Apex is installed, they get the following error:

RuntimeError: Found param bert.embeddings.word_embeddings.weight with type torch.FloatTensor, expected torch.cuda.FloatTensor.
When using amp.initialize, you need to provide a model with parameters
located on a CUDA device before passing it no matter what optimization level
you chose. Use model.to('cuda') to use the default device.

Therefore, amp.initialize() should only be called when Apex is installed and the model is being trained on an available CUDA device.
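A minimal sketch of such a guard follows. It is not Saber's actual fix: the helper names (`apex_is_installed`, `cuda_is_available`, `should_use_amp`, `maybe_initialize_amp`) and the `fp16` flag are hypothetical. The only real library call is Apex's `amp.initialize(model, optimizer, opt_level=...)`, which is imported lazily so that CPU-only installs never need Apex.

```python
import importlib.util


def apex_is_installed() -> bool:
    """Return True if the `apex` package can be found on the import path."""
    return importlib.util.find_spec("apex") is not None


def cuda_is_available() -> bool:
    """Return True if PyTorch is installed and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def should_use_amp(fp16: bool, device, apex_installed: bool, cuda_available: bool) -> bool:
    """Decide whether amp.initialize() may be called.

    AMP requires the model's parameters to live on a CUDA device, so mixed
    precision is enabled only when it was requested (hypothetical `fp16` flag),
    Apex is installed, CUDA is available, and the selected device is a CUDA device.
    """
    return (
        fp16
        and apex_installed
        and cuda_available
        and str(device).startswith("cuda")  # accepts "cuda", "cuda:0", or a torch.device
    )


def maybe_initialize_amp(model, optimizer, fp16: bool, device, opt_level: str = "O1"):
    """Wrap the model and optimizer with Apex AMP only when it is safe to do so.

    Assumes the caller has already moved `model` to `device` before building
    `optimizer`, as amp.initialize expects CUDA-resident parameters.
    """
    if should_use_amp(fp16, device, apex_is_installed(), cuda_is_available()):
        from apex import amp  # lazy import: only needed on the AMP path

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    # Otherwise, fall through and train in full precision on the chosen device.
    return model, optimizer
```

Keeping the decision in a pure function (`should_use_amp`) makes the CPU/GPU logic easy to test without a GPU or an Apex install.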