Open HarperHao opened 11 months ago
```python
weight = _extract_into_tensor(1 / (self.p2_k + self.snr)**self.p2_gamma, t, target.shape)
```

Why is the numerator 1 in the code, not the $\lambda_t$ mentioned in equation 8?
@HarperHao @jychoi118 I have the same question: why does $\lambda_t$ equal 1? Did you find out the reason?
Sorry. In the end I modified the code to replace the 1 with $\lambda_t$.
@HarperHao @LinWeiJeff @jychoi118 The way I understand it is as follows:
The original VLB loss term $L_t$ can be rewritten as the expected value of a weighted MSE between the predicted noise and the actual noise; this is equation 5 in the paper (sketched below).

However, we are using $L_{simple}$, which drops this weighting factor (so the weighting factor becomes one). In terms of the VLB, this means we are multiplying equation 5 by the inverse of the weighting factor. The authors refer to this inverse as $\lambda_t$, or the "baseline" weighting.
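A minimal sketch of the two objectives, assuming the standard DDPM coefficient (I am reconstructing the missing equation from Ho et al.'s $L_{t-1}$ term, so treat the exact constant as my assumption rather than a quote from the P2 paper):

$$L_t = \mathbb{E}_{x_0,\epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t(1-\bar{\alpha}_t)}\,\lVert\epsilon - \epsilon_\theta(x_t, t)\rVert^2\right], \qquad L_{simple} = \mathbb{E}_{x_0,\epsilon}\left[\lVert\epsilon - \epsilon_\theta(x_t, t)\rVert^2\right] = \lambda_t\,L_t,$$

so $\lambda_t$ is the inverse of the bracketed coefficient.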
Later on, in equation 8, the authors present a new weighting scheme that replaces $\lambda_t$ with $\lambda'_t$ (restated below).
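As I read it, equation 8 defines the new weight as

$$\lambda'_t = \frac{\lambda_t}{(k + \mathrm{SNR}(t))^\gamma}, \qquad \mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}.$$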
They then state the following:
Note how the authors are using $L_t$ and NOT $L_{simple}$ in the highlighted text. Remember, $L_t$ still has the weighting factor attached to it, and to get rid of it you need to multiply by $\lambda_t$. In the code, however, we are using $L_{simple}$, which is inherently already multiplied by $\lambda_t$. This means we can simply multiply $L_{simple}$ by $\frac{1}{(k + \mathrm{SNR}(t))^\gamma}$ to get the P2 weighting scheme.
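Spelling the substitution out in one line of algebra (using the definitions above):

$$\lambda'_t\,L_t = \frac{\lambda_t}{(k + \mathrm{SNR}(t))^\gamma}\,L_t = \frac{1}{(k + \mathrm{SNR}(t))^\gamma}\,L_{simple},$$

which is exactly the weight that appears in the code: the numerator is 1 because $\lambda_t$ has already been absorbed into $L_{simple}$.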
I hope my explanation was clear. @jychoi118 Please correct me if I misinterpreted any of your claims.
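For concreteness, here is a minimal, self-contained sketch of this weighting, assuming the standard linear beta schedule and $\mathrm{SNR}(t) = \bar{\alpha}_t/(1-\bar{\alpha}_t)$; the function and variable names (`p2_loss_weight`, `weighted_simple_loss`) are my own, not the repo's:

```python
import torch

def p2_loss_weight(betas: torch.Tensor, p2_k: float = 1.0, p2_gamma: float = 0.5) -> torch.Tensor:
    """Per-timestep P2 weight 1 / (k + SNR(t))**gamma applied to L_simple.

    SNR(t) = alpha_bar_t / (1 - alpha_bar_t). The baseline lambda_t is already
    absorbed by L_simple, so the numerator here is 1 rather than lambda_t.
    """
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # alpha_bar_t
    snr = alphas_cumprod / (1.0 - alphas_cumprod)        # SNR(t)
    return 1.0 / (p2_k + snr) ** p2_gamma

def weighted_simple_loss(noise: torch.Tensor, noise_pred: torch.Tensor,
                         t: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """L_simple (plain epsilon-MSE) scaled by the P2 weight at each sampled t."""
    mse = ((noise - noise_pred) ** 2).mean(dim=list(range(1, noise.ndim)))
    return (weight.to(noise.device)[t] * mse).mean()

# Example: 1000-step linear schedule, as in DDPM.
betas = torch.linspace(1e-4, 0.02, 1000)
weight = p2_loss_weight(betas, p2_k=1.0, p2_gamma=0.5)
```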