Closed LittlePea13 closed 5 years ago
Thank you! I will get around to this, as it will cause issues on multi-GPU machines.
done
The forward passes of the non-LSTM architectures are missing the device argument that is passed to them when they are called in the document_bert script, which leads to an error.
It can be fixed by adding the argument, as in the LSTM architecture:
def forward(self, document_batch: torch.Tensor, document_sequence_lengths: list, freeze_bert=False, device='cuda'):
at lines 72 and 117 of the document_bert_architectures file. (I could open a pull request, but I assumed you would rather fix it yourself.) PS: Thanks for your work!
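For context, a minimal sketch of the fix. The class below is a hypothetical stand-in for the repo's non-LSTM architectures (the real ones wrap BERT); the point is only the signature change: forward accepts a device argument and moves tensors to it explicitly, instead of assuming the default CUDA device, which is what breaks on multi-GPU machines.

```python
import torch
from torch import nn


class DocumentClassifier(nn.Module):
    """Hypothetical stand-in for a non-LSTM document architecture."""

    def __init__(self, hidden: int = 8, num_labels: int = 2):
        super().__init__()
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, document_batch: torch.Tensor,
                document_sequence_lengths: list,
                freeze_bert: bool = False,
                device: str = 'cuda'):
        # Move the input to the caller-specified device rather than
        # relying on whatever device the default context happens to be.
        document_batch = document_batch.to(device)
        # Stand-in for the BERT encoding/pooling step.
        pooled = document_batch.mean(dim=1)
        return self.classifier(pooled)


model = DocumentClassifier()
# The caller (document_bert) passes device explicitly; 'cpu' here for the sketch.
out = model(torch.randn(4, 3, 8), [3, 3, 3, 3], device='cpu')
print(out.shape)
```

With the extra parameter in place, the call sites in document_bert that already pass device=... no longer raise a TypeError.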