lucidrains / voicebox-pytorch

Implementation of Voicebox, new SOTA Text-to-speech network from MetaAI, in Pytorch
MIT License
589 stars 49 forks

There might be a bug in the loss calculation #6

Closed: yzmyyff closed this issue 1 year ago

yzmyyff commented 1 year ago

First, we set the loss at all masked positions to 0:

loss = loss.masked_fill(mask, 0.)

In the paper, they set the loss at all unmasked positions to 0 and average over the masked positions.

If our mask is precisely the inverse of the mask in the paper, we need to average over the unmasked positions. But the denominator still counts the masked positions:

den = mask.sum(dim = -1).clamp(min = 1e-5)
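
To make the mismatch concrete, here is a minimal plain-Python sketch of the arithmetic for a single sequence. It is not the repository's tensor code: the function names and the toy values are hypothetical, and it assumes `mask` is `True` exactly at the positions that `masked_fill` zeroes out.

```python
def zero_then_divide_by_mask_count(loss, mask, eps=1e-5):
    # Mirrors the two quoted lines: zero the loss where mask is True,
    # then divide by the number of True entries in mask.
    zeroed = [0.0 if m else l for l, m in zip(loss, mask)]
    den = max(sum(mask), eps)
    return sum(zeroed) / den


def mean_over_kept(loss, mask, eps=1e-5):
    # Consistent variant: the denominator counts the same positions
    # that still contribute to the numerator (where mask is False).
    keep = [not m for m in mask]
    num = sum(l for l, k in zip(loss, keep) if k)
    den = max(sum(keep), eps)
    return num / den


# Hypothetical per-position losses; only the first position is masked.
loss = [1.0, 2.0, 3.0, 4.0]
mask = [True, False, False, False]

mismatched = zero_then_divide_by_mask_count(loss, mask)
consistent = mean_over_kept(loss, mask)
print(mismatched, consistent)  # 9.0 3.0
```

In this toy case three positions contribute to the sum, but the first variant divides by one, so the result is three times the true mean over those positions.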

lucidrains commented 1 year ago

@yzmyyff thanks Kuang!