ucam-smt / sgnmt

Decoding platform for machine translation research
http://ucam-smt.github.io/sgnmt/html/
Apache License 2.0

pytorch_fairseq.py - change needed to run on GPU #9

Open wjbyrne opened 3 years ago

wjbyrne commented 3 years ago

Something like the following seems to be needed to run on GPU.

```python
def predict_next(self):
    """Call the fairseq model."""
    if self.use_cuda:
        lprobs, _ = self.model.forward_decoder(
            torch.cuda.LongTensor([self.consumed]), self.encoder_outs
        )
        lprobs[0, self.pad_id] = utils.NEG_INF
        return np.array(lprobs[0].cpu())
    else:
        lprobs, _ = self.model.forward_decoder(
            torch.LongTensor([self.consumed]), self.encoder_outs
        )
        lprobs[0, self.pad_id] = utils.NEG_INF
        return np.array(lprobs[0])
```
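The two branches above differ only in which tensor constructor is used and whether the result is moved back to the CPU before conversion, so they could be folded into one path by picking the constructor once. A minimal stdlib-only sketch of that dispatch pattern (the `CpuTensor`/`GpuTensor` stubs and `make_input` helper below are hypothetical stand-ins for `torch.LongTensor`/`torch.cuda.LongTensor`, not part of SGNMT or fairseq):

```python
class CpuTensor:
    """Stand-in for torch.LongTensor (illustration only)."""
    def __init__(self, data):
        self.data = data
        self.device = "cpu"

class GpuTensor(CpuTensor):
    """Stand-in for torch.cuda.LongTensor (illustration only)."""
    def __init__(self, data):
        super().__init__(data)
        self.device = "cuda"

    def cpu(self):
        # Copy back to the CPU, as lprobs[0].cpu() does in the patch above.
        return CpuTensor(self.data)

def make_input(consumed, use_cuda):
    # Pick the constructor once instead of duplicating the whole
    # forward_decoder call in both branches.
    tensor_cls = GpuTensor if use_cuda else CpuTensor
    return tensor_cls([consumed])
```

In the real predictor the same fold would read `tensor_cls = torch.cuda.LongTensor if self.use_cuda else torch.LongTensor`, followed by a single `forward_decoder` call.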

fstahlberg commented 3 years ago

Is this code tested? If yes, I'm happy to update the code, or you can also make a PR / push it directly if you like.

wjbyrne commented 3 years ago

Hi Felix!

At the moment, I've only tested it on the example in https://ucam-smt.github.io/sgnmt/html/tutorial_pytorch.html

Is that enough testing? I could decode a few test sets with a few additional models, if you like.

btw, just to note that fairseq is now at version 1.10.x and no longer compatible with sgnmt

Bill


fstahlberg commented 3 years ago

That's enough testing for now - I'll have a go at updating it to 1.10.x in a few weeks