rpryzant / delete_retrieve_generate

PyTorch implementation of the Delete, Retrieve, Generate style transfer algorithm
MIT License
132 stars, 26 forks

Getting an index out of bounds error #25

Closed: naveen-kinnal closed this issue 3 years ago

naveen-kinnal commented 3 years ago

Hello. I am trying to run the code on the GYAFC dataset (source: informal, target: formal). I have kept the same configuration as for Yelp. However, while training the model, I get an index-out-of-bounds exception during the evaluation that runs after each epoch, on this line:

```python
lens = [lens[j] for j in idx]
```

```python
if sort:
    # sort sequence by descending length
    idx = [x[0] for x in sorted(enumerate(lens), key=lambda x: -x[1])]

if idx is not None:
    lens = [lens[j] for j in idx]
```

in the method `get_minibatch`. Could you shed some light on what the issue might be? Is it a problem with the dataset? Also, is sorting really necessary?
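One plausible way this line fails, sketched under an assumption: suppose `idx` is computed by sorting one batch (e.g. the source sentences) and then reused to reorder a second batch (e.g. the target sentences). If the second batch has fewer sentences than the first, `idx` contains positions that do not exist in the shorter `lens` list. The function `reorder_lengths` below is a hypothetical stand-in that copies only the sorting logic quoted above; it is not the repository's `get_minibatch`.

```python
from typing import List, Optional, Tuple


def reorder_lengths(
    lines: List[List[str]], sort: bool = False, idx: Optional[List[int]] = None
) -> Tuple[List[int], Optional[List[int]]]:
    """Hypothetical stand-in for the length-sorting part of get_minibatch."""
    lens = [len(line) for line in lines]
    if sort:
        # sort sequence by descending length, remembering the permutation
        idx = [x[0] for x in sorted(enumerate(lens), key=lambda x: -x[1])]
    if idx is not None:
        # apply a permutation that may have been computed on a different batch
        lens = [lens[j] for j in idx]
    return lens, idx


src_batch = [["a", "b"], ["c"], ["d", "e", "f"]]  # 3 source sentences
tgt_batch = [["x"], ["y", "z"]]                   # only 2 target sentences

src_lens, src_idx = reorder_lengths(src_batch, sort=True)
print(src_lens, src_idx)  # [3, 2, 1] [2, 0, 1]

try:
    # idx contains position 2, but the target batch only has positions 0 and 1
    reorder_lengths(tgt_batch, idx=src_idx)
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range
```

If both batches have the same number of sentences, the permutation is always valid, which would be consistent with the workaround reported at the end of this thread.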

rpryzant commented 3 years ago

Thanks for reaching out. The sort is (or was; this requirement may have been dropped in newer versions of PyTorch) needed to pack padded inputs before passing them to the LSTM encoder:

https://github.com/rpryzant/delete_retrieve_generate/blob/master/src/encoders.py#L49

rpryzant commented 3 years ago

https://github.com/pytorch/pytorch/issues/3584
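The sorting requirement comes from `torch.nn.utils.rnn.pack_padded_sequence`: with `enforce_sorted=True` (the default), `lengths` must be in decreasing order. Since PyTorch 1.1 you can pass `enforce_sorted=False` and PyTorch sorts and unsorts the batch internally. The toy comparison below is a sketch, not the repository's encoder; the vocabulary size, dimensions and token ids are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Three padded sequences (batch_first, pad id 0), NOT sorted by length.
batch = torch.tensor([[1, 2, 0, 0],
                      [3, 4, 5, 6],
                      [7, 0, 0, 0]])
lengths = torch.tensor([2, 4, 1])

embed = nn.Embedding(10, 8, padding_idx=0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# Option 1: sort by descending length before packing (the older approach).
sorted_lens, order = lengths.sort(descending=True)
packed = pack_padded_sequence(embed(batch[order]), sorted_lens, batch_first=True)
out_sorted, _ = lstm(packed)
out_sorted, _ = pad_packed_sequence(out_sorted, batch_first=True)
# Undo the sort so rows line up with the original batch order again.
out_sorted = out_sorted[order.argsort()]

# Option 2 (PyTorch >= 1.1): let pack_padded_sequence handle the ordering.
packed = pack_padded_sequence(embed(batch), lengths, batch_first=True,
                              enforce_sorted=False)
out_unsorted, _ = lstm(packed)
out_unsorted, _ = pad_packed_sequence(out_unsorted, batch_first=True)

print(out_unsorted.shape)                        # torch.Size([3, 4, 16])
print(torch.allclose(out_sorted, out_unsorted))  # True
```

Both options give the same per-sequence outputs; the second simply avoids keeping track of the permutation by hand. Whether `get_minibatch` could then drop its `idx` argument depends on how the rest of the data pipeline uses it, which this thread does not settle.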

naveen-kinnal commented 3 years ago

Thank you @rpryzant, I found a quick fix by equalizing the sizes of the source and target datasets. It's working fine! I will try out the different padding mechanism soon, though :)
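A minimal sketch of one way to equalize the two datasets, assuming "equalizing" means giving the source and target files the same number of lines (one sentence per line) by truncating the longer one. The helpers `equalize` and `equalize_files` and the file names are hypothetical; they are not part of the repository.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def equalize(src_lines: List[str], tgt_lines: List[str]) -> Tuple[List[str], List[str]]:
    """Truncate the longer corpus so both contain the same number of sentences."""
    n = min(len(src_lines), len(tgt_lines))
    return src_lines[:n], tgt_lines[:n]


def equalize_files(src_path: Path, tgt_path: Path) -> int:
    """Rewrite both files in place with equal line counts; return that count."""
    src = src_path.read_text(encoding="utf-8").splitlines()
    tgt = tgt_path.read_text(encoding="utf-8").splitlines()
    src, tgt = equalize(src, tgt)
    # one sentence per line, each terminated by a newline
    src_path.write_text("".join(line + "\n" for line in src), encoding="utf-8")
    tgt_path.write_text("".join(line + "\n" for line in tgt), encoding="utf-8")
    return len(src)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        src_file = Path(tmp) / "informal.txt"  # hypothetical file names
        tgt_file = Path(tmp) / "formal.txt"
        src_file.write_text("hey u there\nthx a lot\nlol ok\n", encoding="utf-8")
        tgt_file.write_text("Are you there?\nThank you very much.\n", encoding="utf-8")
        print(equalize_files(src_file, tgt_file))  # 2
```

Truncation is the simplest option but throws away data from the larger side; randomly sampling (or oversampling) to a common size would be an alternative with different trade-offs.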