IBM / pytorch-seq2seq

An open source framework for seq2seq models in PyTorch.
https://ibm.github.io/pytorch-seq2seq/public/index.html
Apache License 2.0
1.5k stars 376 forks

Decode function in decoder #179

Closed ihungalexhsu closed 6 years ago

ihungalexhsu commented 6 years ago

I think there is an issue in the decode function of Decoder.py, in the part that calculates the lengths. Specifically:

update_idx = ((lengths > step) & eos_batches) != 0
lengths[update_idx] = len(sequence_symbols)

Here update_idx is a [0,1,0,1,.....] sequence, not the indices of the nonzero entries. I think it should be:

update_idx = np.nonzero((lengths > step) & eos_batches)

Can you check this part? I'm not entirely sure that I'm correct. Thanks.

pskrunner14 commented 6 years ago

Hi @ihungalexhsu, I think they're actually the same operation, as shown below:

>>> import numpy as np
>>> a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
>>> b = a.copy()  # independent copy, so writing to a does not also change b
>>> a
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> b
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> a[a > 5] = 1  # access using boolean array with same dims
>>> a
array([1, 2, 3, 4, 5, 1, 1, 1, 1, 1])
>>> b[np.array([5, 6, 7, 8, 9])] = 1  # access using indices
>>> b
array([1, 2, 3, 4, 5, 1, 1, 1, 1, 1])

At first I thought it was a bug too but the end result says otherwise.
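The same comparison can be sketched with the variable names from the issue. This is only an illustration: the values of `lengths`, `eos_batches`, `step`, and `new_length` below are made up, and `new_length` stands in for `len(sequence_symbols)`. It suggests why the `!= 0` matters: it turns the result into a boolean mask, whereas an integer 0/1 array would be read by NumPy as a list of positions.

```python
import numpy as np

# Hypothetical values for illustration; not taken from the repository.
lengths = np.array([10, 10, 10, 10])                   # current length per batch item
eos_batches = np.array([0, 1, 0, 1], dtype=np.uint8)   # 1 where EOS was produced at this step
step = 3
new_length = 4                                         # stands in for len(sequence_symbols)

# Boolean-mask version, as quoted in the issue.
mask = ((lengths > step) & eos_batches) != 0           # dtype is bool after "!= 0"
by_mask = lengths.copy()
by_mask[mask] = new_length

# np.nonzero version, as proposed in the issue.
idx = np.nonzero((lengths > step) & eos_batches)       # tuple of index arrays
by_index = lengths.copy()
by_index[idx] = new_length

# Without "!= 0" the array stays uint8, so NumPy treats it as integer
# indices [0, 1, 0, 1] and writes to positions 0 and 1 only.
int_mask = (lengths > step) & eos_batches
by_int = lengths.copy()
by_int[int_mask] = new_length

print(by_mask)    # boolean mask result
print(by_index)   # np.nonzero result, same as above
print(by_int)     # integer-index result, different
```

So the two forms in the issue should agree as long as the mask is boolean; the surprising behaviour would only appear if the `!= 0` were dropped and the array stayed an integer type.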

ihungalexhsu commented 6 years ago

Got it! Thanks for the clarification.