writer / fitbert

Use BERT to Fill in the Blanks
https://pypi.org/project/fitbert/
Apache License 2.0

Fix PyTorch device error when loading custom model #18

Closed · JasonObeid closed this issue 4 years ago

JasonObeid commented 4 years ago

Using Python 3.8, fitbert 0.7.0, transformers 2.9.1, and torch 1.5.0.

When loading a custom Transformers model as described in the readme using `BertForMaskedLM.from_pretrained('path to pretrained')`, a runtime error occurs:

`Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select`
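For context, here is a minimal sketch of the setup that triggers the error. The checkpoint path, tokenizer, and example sentence are placeholders, and the `model=`/`tokenizer=` keyword arguments are assumed from the readme's custom-model example:

```python
from transformers import BertForMaskedLM, BertTokenizer
from fitbert import FitBert

# Placeholder path to a local fine-tuned masked-LM checkpoint.
model = BertForMaskedLM.from_pretrained('path to pretrained')
tokenizer = BertTokenizer.from_pretrained('path to pretrained')

# On a CUDA machine, fitbert 0.7.0 moves the input tensors to the GPU
# but leaves this custom model on the CPU, so the first call below
# raises the device-mismatch RuntimeError quoted above.
fb = FitBert(model=model, tokenizer=tokenizer)

# The readme uses ***mask*** as the blank marker.
fb.rank("This package is ***mask*** to use.", options=["easy", "hard"])
```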

The issue comes from line 151, `tens = tens.to(self.device)`: the input tensors are moved to `self.device` (CUDA when available) while the custom model is still on the CPU, so the forward pass fails with the device mismatch above.

Adding `self.bert.to(self.device)` at line 41 fixes the issue.
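Until that change is released, the same idea works as a user-side workaround. A minimal sketch, relying on the `bert` and `device` attributes referenced above (the checkpoint path, keyword arguments, and example sentence are placeholders as before):

```python
from transformers import BertForMaskedLM, BertTokenizer
from fitbert import FitBert

model = BertForMaskedLM.from_pretrained('path to pretrained')      # placeholder path
tokenizer = BertTokenizer.from_pretrained('path to pretrained')

fb = FitBert(model=model, tokenizer=tokenizer)

# Same fix applied from the outside: move the custom model onto
# whatever device FitBert selected (cuda if available, else cpu).
fb.bert.to(fb.device)

print(fb.rank("This package is ***mask*** to use.", options=["easy", "hard"]))
```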

sam-writer commented 4 years ago

@JasonObeid I released this change; it should be available in version 0.9.0.

Looks great! Thanks for contributing!

JasonObeid commented 4 years ago

No problem Sam, thanks for this great package!