SamLynnEvans / Transformer

Transformer seq2seq model, program that can build a language translator from parallel corpus
Apache License 2.0

IndexError when testing with python translate.py #3

Open romran opened 5 years ago

romran commented 5 years ago

Hello, many thanks for sharing this project. Unfortunately, I am getting `IndexError: index 0 is out of bounds for dimension 0 with size 0` when running `python translate.py -load_weights weights -src_lang en -trg_lang fr -floyd -no_cuda` on FloydHub and entering text for translation. Does anyone know where the problem could be?

fabrahman commented 5 years ago

@romran I am getting the same problem for some input texts, but not all of them. I am not sure where the issue is, though.

oscarberonius commented 5 years ago

This is due to a bug in the beam search module. I made a pull request with a fix here https://github.com/SamLynnEvans/Transformer/pull/5

You can copy/paste the changes if you just want to try it out :)

fabrahman commented 5 years ago

@oscarberonius Hi, thanks for your effort. It is still giving an IndexError for some inputs. The inputs that previously raised the error work now, but it throws the error for other inputs instead.

oscarberonius commented 5 years ago

Hi @Hannabrahman. That is indeed strange; I still don't have problems with any of my inputs. Maybe yours contain multiple start tokens?

Anyway, let me know if you find a solution. Good luck!

xiaohongniua commented 5 years ago

Change the code in beam.py as follows. There seems to be a bug in the original code.

The original code tries to stop when there are k EOS tokens in the final top-k sentences. However, it is possible for topi to contain more than one EOS, and when this happens the bug is triggered.

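    # Index of the first EOS in each beam (argmax returns 0 for a row with no EOS).
    first_eos = (outputs == eos_tok).cpu().numpy().argmax(axis=1)
    # Stop only once every one of the k beams contains an EOS.
    if first_eos.nonzero()[0].shape[0] == opt.k:
        alpha = 0.7
        # Length-normalise the scores by each beam's length up to its first EOS.
        div = 1 / (torch.tensor(first_eos).type_as(log_scores) ** alpha)
        _, ind = torch.max(log_scores * div, 1)
        ind = ind.data[0]
        break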
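
To illustrate the idea behind this check without needing the model, here is a minimal pure-Python sketch. It is not the code from beam.py: the function names `first_eos_positions` and `select_best_beam`, the token ids, and the scores are all made up for the example. It assumes position 0 of every beam holds a start token, so a first-EOS index of 0 can stand for "no EOS yet". It only shows why counting beams that contain an EOS differs from counting EOS tokens.

```python
def first_eos_positions(outputs, eos_tok):
    """Return the index of the first EOS in each row, or 0 if the row has none."""
    return [row.index(eos_tok) if eos_tok in row else 0 for row in outputs]


def select_best_beam(outputs, log_scores, eos_tok, k, alpha=0.7):
    """Return the index of the best beam once all k beams contain an EOS.

    Returns None while at least one beam is still unfinished.
    """
    lengths = first_eos_positions(outputs, eos_tok)
    finished = sum(1 for n in lengths if n > 0)  # beams with an EOS, not EOS tokens
    if finished != k:
        return None
    # Length-normalised score: log_score / length ** alpha
    normalised = [s / (n ** alpha) for s, n in zip(log_scores, lengths)]
    return max(range(k), key=lambda i: normalised[i])


# Hypothetical data: token id 3 is EOS, 1 is the start token, 0 is padding.
EOS = 3
beams = [
    [1, 5, 3, 3, 0],  # this beam contains two EOS tokens
    [1, 6, 7, 3, 0],
]
scores = [-1.0, -1.2]

total_eos_tokens = sum(row.count(EOS) for row in beams)  # counts every EOS token
best = select_best_beam(beams, scores, EOS, k=2)
```

Here `total_eos_tokens` is 3 even though only 2 beams exist, so a stopping test that compares the raw EOS count with k would never see equality for this input, while the per-beam count does.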