lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

Removal of the last token id from fine_token_ids in FineTransformerWrapper.forward() #261

Closed · biendltb closed this issue 6 months ago

biendltb commented 6 months ago

Hi,

Thanks for the amazing work. This implementation is a great complement to the paper and clarifies many things that were not explained in enough detail in the AudioLM paper.

While reading the code, one small thing I don't understand is why the last token id is removed from the fine token ids during training.

        if return_loss:
            coarse_labels = coarse_token_ids
            fine_labels = fine_token_ids
            fine_token_ids = fine_token_ids[:, :-1]

https://github.com/lucidrains/audiolm-pytorch/commit/65495ad5b060bfc90d056d4a161d2df5060eea6a#diff-96a5ee045c1df07f3125d9b4189130620f229785b36cebb86c95b0646f0d744dR1487

From the above commit, the eos id is not appended to the fine token ids fed to the transformer decoder (i.e. the eos id is not the last token id). However, we still remove the last token id from the fine token ids. I think this token id is still needed for training.

Could anyone clarify this? Thank you for your help.

biendltb commented 6 months ago

Closing this issue, since the length of the fine token ids is tied to the length of the pre-existing coarse token ids, so the model does not need to predict an EOS id.
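
For anyone landing here later, here is a minimal sketch of how I understand the training step. This is not the actual FineTransformerWrapper code; the fine_training_step function, the model signature, and the shapes are assumptions for illustration only. The labels cover every fine position, while the decoder input drops the last fine token: the coarse tokens act as a prefix and fix the output length, so nothing has to be predicted after the final fine token and no EOS is needed.

    import torch.nn.functional as F

    # Minimal sketch, not the actual FineTransformerWrapper code.
    # Assumes a hypothetical model(coarse_token_ids, fine_input) that returns
    # one logit vector per fine position, so logits[:, t] predicts
    # fine_labels[:, t] from the coarse tokens plus fine tokens < t.
    def fine_training_step(model, coarse_token_ids, fine_token_ids):
        fine_labels = fine_token_ids              # (batch, n) - every fine token is a target
        fine_input = fine_token_ids[:, :-1]       # (batch, n - 1) - last token dropped:
                                                  # the coarse tokens fix the total length,
                                                  # so nothing follows the last fine token
                                                  # and no EOS has to be predicted

        logits = model(coarse_token_ids, fine_input)   # (batch, n, codebook_size)

        # cross entropy expects (batch, classes, positions)
        loss = F.cross_entropy(logits.transpose(1, 2), fine_labels)
        return loss

With this setup the output length is determined entirely by the coarse tokens, which is why the wrapper can drop both the EOS token and the last fine token from the decoder input without losing any training signal.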