AndriyMulyar / bert_document_classification

Architectures and pre-trained models for long document classification.
154 stars, 47 forks

Missing device argument in the non-LSTM models #1

Closed: LittlePea13 closed this issue 5 years ago

LittlePea13 commented 5 years ago

The forward methods of the non-LSTM architectures do not accept the device argument that the document_bert script passes when it calls them, so those calls fail with an error.
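A minimal, torch-free reproduction of this failure mode, assuming the script passes device as a keyword argument (if it were passed positionally, Python would instead complain about too many positional arguments). Both class names and the call_forward helper are hypothetical stand-ins, not the repository's code.

```python
class LSTMStyleArchitecture:
    """Hypothetical stand-in whose forward accepts device, like the LSTM model."""

    def forward(self, document_batch, document_sequence_lengths, freeze_bert=False, device='cuda'):
        return f"ran on {device}"


class NonLSTMArchitecture:
    """Hypothetical stand-in whose forward is missing the device parameter."""

    def forward(self, document_batch, document_sequence_lengths, freeze_bert=False):
        return "ran"


def call_forward(model, device):
    # Hypothetical caller that always passes device as a keyword argument.
    return model.forward([], [], freeze_bert=False, device=device)


if __name__ == "__main__":
    print(call_forward(LSTMStyleArchitecture(), "cpu"))  # prints: ran on cpu
    try:
        call_forward(NonLSTMArchitecture(), "cpu")
    except TypeError as err:
        # Python rejects the unexpected keyword before forward's body runs.
        print("TypeError:", err)
```

The fix is therefore purely a signature change: once every architecture's forward declares device, the same call succeeds for all of them.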

It can be fixed by adding the parameter the same way the LSTM architecture does, def forward(self, document_batch: torch.Tensor, document_sequence_lengths: list, freeze_bert=False, device='cuda'):, at lines 72 and 117 of the document_bert_architectures file.
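A sketch of how a non-LSTM forward might look with the patched signature. This is a simplified, hypothetical module (the class name is invented and an nn.Linear stands in for the BERT encoder), not the repository's implementation. It only illustrates accepting device and using it for a tensor allocated inside forward; document_sequence_lengths is kept solely so the signature matches the one quoted above.

```python
import torch
from torch import nn


class SketchDocumentClassifier(nn.Module):
    """Hypothetical non-LSTM architecture; nn.Linear stands in for BERT."""

    def __init__(self, hidden_size: int = 8, num_labels: int = 2):
        super().__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)  # stand-in for BERT
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, document_batch: torch.Tensor, document_sequence_lengths: list,
                freeze_bert=False, device='cuda'):
        # document_batch: (batch_size, num_sequences, hidden_size)
        batch_size, num_sequences, hidden_size = document_batch.shape
        # Allocate the per-sequence buffer on the device the caller passed in.
        encoded = torch.zeros(batch_size, num_sequences, hidden_size,
                              dtype=torch.float, device=device)
        for i in range(num_sequences):
            if freeze_bert:
                # Skip gradient tracking through the encoder when it is frozen.
                with torch.no_grad():
                    encoded[:, i] = self.encoder(document_batch[:, i])
            else:
                encoded[:, i] = self.encoder(document_batch[:, i])
        # Mean-pool over sequences, then classify.
        return self.classifier(encoded.mean(dim=1))


if __name__ == "__main__":
    model = SketchDocumentClassifier()
    batch = torch.randn(3, 4, 8)  # 3 documents, 4 sequences each
    logits = model(batch, [4, 4, 4], freeze_bert=False, device='cpu')
    print(logits.shape)  # torch.Size([3, 2])
```

On a machine with several GPUs, the caller would pass the model's actual device (for example device='cuda:1'). The default 'cuda' refers to the current default GPU, so a buffer created with it can end up on a different GPU from the model's weights, which is the multi-GPU issue mentioned in the reply.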

(I can open a pull request, but I assumed you would rather fix it yourself.) PS: Thanks for your work!

AndriyMulyar commented 5 years ago

Thank you! Will get around to this, as it will cause issues on multi-GPU machines.

(Email reply quoting the report above, which Pere Lluis sent on Fri, Nov 1, 2019, 9:45 AM.)

AndriyMulyar commented 5 years ago

Done.