lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Loss about CoarseTransformerWrapper #185

Closed asr-pub closed 1 year ago

asr-pub commented 1 year ago

https://github.com/lucidrains/audiolm-pytorch/blob/504a7e171d5246c9efa171c0f61a2b8f7f50ce0b/audiolm_pytorch/audiolm_pytorch.py#L1484-L1487

I think the loss above should be changed to the code below:

        if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits):
            final_loss = (
                semantic_loss * num_semantic_logits * self.semantic_cross_entropy_loss_weight +
                coarse_loss * num_coarse_logits
            ) / (num_semantic_logits + num_coarse_logits)
        else:
            final_loss = coarse_loss

        return final_loss
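
To illustrate the weighting proposed above, here is a minimal plain-Python sketch. It is not the library's code: `combine_losses` is a hypothetical helper, the per-token loss values are made up, and the check on `num_semantic_logits == 0` only stands in for the `exists(semantic_logits)` test. It shows that weighting each mean loss by its token count, with a semantic weight of 1.0, gives the same value as averaging over all tokens at once.

```python
def combine_losses(semantic_loss, coarse_loss,
                   num_semantic_logits, num_coarse_logits,
                   semantic_weight=1.0):
    """Token-count-weighted average of two mean losses (hypothetical helper)."""
    # Stand-in for the `weight > 0 and exists(semantic_logits)` condition.
    if semantic_weight <= 0 or num_semantic_logits == 0:
        return coarse_loss
    total = num_semantic_logits + num_coarse_logits
    return (
        semantic_loss * num_semantic_logits * semantic_weight
        + coarse_loss * num_coarse_logits
    ) / total


# Made-up per-token losses, for illustration only.
semantic_token_losses = [2.0, 4.0]          # 2 semantic tokens
coarse_token_losses = [1.0, 1.0, 1.0, 3.0]  # 4 coarse tokens

# Each branch reports the mean loss over its own tokens.
semantic_loss = sum(semantic_token_losses) / len(semantic_token_losses)
coarse_loss = sum(coarse_token_losses) / len(coarse_token_losses)

combined = combine_losses(semantic_loss, coarse_loss,
                          len(semantic_token_losses), len(coarse_token_losses))

# With weight 1.0 the combined loss matches the mean over all six tokens.
pooled = sum(semantic_token_losses + coarse_token_losses) / 6
```

With a semantic weight other than 1.0 the result is no longer the pooled mean, since the semantic term is scaled but the denominator still counts every token.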
lucidrains commented 1 year ago

@asr-pub oh yes, thanks for raising this! think it should be addressed with this commit

asr-pub commented 1 year ago

Yes, it's right in this commit