Closed yzmyyff closed 1 year ago
First, we set the loss at all masked positions to 0:
loss = loss.masked_fill(mask, 0.)
In the paper, they set the loss at all unmasked positions to 0 and average over the masked positions.
If our mask is precisely the inverse of the mask in the paper, we need to average over the unmasked positions. However, the denominator counts the masked positions instead:
den = mask.sum(dim = -1).clamp(min = 1e-5)
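One possible fix, sketched under assumptions: `mask` is a boolean tensor that is `True` at the positions whose loss should be ignored (consistent with the `masked_fill` call above), and the mean is taken per row over the last dimension. The function name `masked_mean_loss` and the variable `keep` are hypothetical and only for illustration; this is not necessarily how the repository resolved it.

```python
import torch

def masked_mean_loss(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumption: mask is True where the loss should be ignored.
    loss = loss.masked_fill(mask, 0.)
    # Count the positions that still contribute to the sum (the unmasked
    # ones), not the masked ones.
    keep = ~mask
    den = keep.sum(dim=-1).float().clamp(min=1e-5)
    return loss.sum(dim=-1) / den

loss = torch.tensor([[2., 4., 6., 8.]])
mask = torch.tensor([[False, False, False, True]])

fixed = masked_mean_loss(loss, mask)

# For comparison: the denominator that counts masked positions instead.
buggy_den = mask.sum(dim=-1).float().clamp(min=1e-5)
buggy = loss.masked_fill(mask, 0.).sum(dim=-1) / buggy_den
```

With this input, three positions are kept and their losses sum to 12, so `fixed` is 4.0, while dividing by the single masked position gives 12.0.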
@yzmyyff thanks Kuang!