InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

About Rate Distortion Loss #199

Closed Indraa145 closed 1 year ago

Indraa145 commented 1 year ago

Hello, and thank you for your work. I'd like to ask about the difference between the rate-distortion loss formula in your custom training documentation and the one in your RateDistortionLoss class.

In your custom training documentation, the rate-distortion loss is defined as:

$$L = D + \lambda \cdot R$$
import math

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)
net = Network()  # the example model defined earlier in the documentation
x_hat, y_likelihoods = net(x)

# bitrate of the quantized latent
N, _, H, W = x.size()
num_pixels = N * H * W
bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

# mean square error
mse_loss = F.mse_loss(x, x_hat)

# final loss term (lmbda is the rate-distortion trade-off hyperparameter)
loss = mse_loss + lmbda * bpp_loss

While in your RateDistortionLoss class, which is used in your examples/train.py, it is:

$$L = \lambda \cdot 255^2 \cdot D + R$$
# excerpt from the forward method of RateDistortionLoss
# (output: model output dict, target: original images)
N, _, H, W = target.size()
out = {}
num_pixels = N * H * W

out["bpp_loss"] = sum(
    (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
    for likelihoods in output["likelihoods"].values()
)
out["mse_loss"] = self.mse(output["x_hat"], target)
out["loss"] = self.lmbda * 255**2 * out["mse_loss"] + out["bpp_loss"]

return out
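
For reference, here is how I wrap this excerpt into a runnable module for my own experiments. This is just a rough sketch: the class skeleton, the model choice, and the lmbda value are my own placeholders, not necessarily what examples/train.py does.

import math

import torch
import torch.nn as nn
from compressai.zoo import bmshj2018_factorized


class MyRateDistortionLoss(nn.Module):
    """My reconstruction of the loss, wrapping the forward body quoted above."""

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        N, _, H, W = target.size()
        out = {}
        num_pixels = N * H * W
        out["bpp_loss"] = sum(
            (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
            for likelihoods in output["likelihoods"].values()
        )
        out["mse_loss"] = self.mse(output["x_hat"], target)
        out["loss"] = self.lmbda * 255**2 * out["mse_loss"] + out["bpp_loss"]
        return out


net = bmshj2018_factorized(quality=3)        # random weights; enough for a sketch
criterion = MyRateDistortionLoss(lmbda=1e-2)
x = torch.rand(4, 3, 256, 256)
out_net = net(x)                             # {"x_hat": ..., "likelihoods": {"y": ...}}
losses = criterion(out_net, x)
losses["loss"].backward()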

I also notice that there's a difference in the bpp_loss calculation. In the RateDistortionLoss class, you take a sum when computing the bpp_loss. I'd also like to know why this is the case: are you summing the bpp_loss across all the batches?

I'm wondering which loss is better to use, and is there a paper I can refer to regarding this? Thank you very much.

YodaEmbedding commented 1 year ago
  1. The bpp_loss calculation in both code samples is exactly the same for a model that only outputs $y$ and has total rate $R_y$ (e.g. the bmshj2018-factorized model). However, the second code sample will correctly calculate $R_y + R_z$ if the model also outputs $z$. Note that output["likelihoods"].values() == (y_likelihoods, z_likelihoods) in that case. (The sketch at the end of this comment illustrates this.)

  2. Both code samples estimate the average rate over the batch. This is what $R$ typically refers to: the expected value of the negative log likelihood over the entire sample space of images $\mathcal{X} = \{x_1, x_2, \ldots \}$ and their corresponding latents $\mathcal{Y} = \{g_a(x) : x \in \mathcal{X}\} = \{y_1, y_2, \ldots \}$,

    $$R_y = H(Y) = \mathbb{E}_{y \in \mathcal{Y}}[-\log p(y)]$$

    In this case, we are taking a 16-sample (i.e. batch size) Monte Carlo estimate of the entropy.

  3. Optimizing a "16-sample average" $R,D$ is better than a "1-sample" $R,D$, since our goal is good average performance over a test dataset such as Kodak. In practice, I believe there isn't actually that much of a difference with SGD, since it effectively tends to optimize for the average loss anyway. If we were to strongly optimize for individual R-D performance via e.g. a single-sample R-D with an exaggerated loss such as $L = (R + D)^2$ for $L > 1$, the average R-D performance over the dataset might suffer. Or perhaps not.

Figure: Single sample R-D points (blue) and their average (orange) over the Kodak dataset.
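
Here is a quick sketch illustrating points 1 and 2. The model choice, batch size, and image size are arbitrary; it assumes a CompressAI model whose forward pass returns a "likelihoods" dict.

import math

import torch
from compressai.zoo import bmshj2018_hyperprior

# A hyperprior model outputs likelihoods for both y and z.
net = bmshj2018_hyperprior(quality=3)
x = torch.rand(16, 3, 256, 256)  # batch of 16 -> a "16-sample" Monte Carlo estimate
out = net(x)

print(list(out["likelihoods"].keys()))  # expect ['y', 'z']

N, _, H, W = x.size()
num_pixels = N * H * W  # dividing by N*H*W averages the rate over the whole batch

# R_y + R_z in bits per pixel, averaged over the 16 images.
bpp = sum(
    torch.log(lk).sum() / (-math.log(2) * num_pixels)
    for lk in out["likelihoods"].values()
)
print(f"estimated rate: {bpp.item():.3f} bpp")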

Indraa145 commented 1 year ago

Thank you for the answer, but I'm still a bit confused about the difference between the formulas for the final loss term. In the custom training documentation, it's defined as:

$$L = D + \lambda \cdot R$$

While in the RateDistortionLoss class, it's defined as:

$$L = \lambda \cdot 255^2 \cdot D + R$$

Why is $\lambda$ multiplied by $255^2$ and $D$ here, as opposed to being multiplied by $R$ as in the custom training documentation example?

YodaEmbedding commented 1 year ago

It's just a scaling constant. It has no tangible effect, as long as $\lambda$ is rescaled accordingly. (Likewise, it doesn't matter whether $\lambda$ multiplies $R$ or $D$: minimizing $D + \lambda R$ is the same as minimizing $\tfrac{1}{\lambda} D + R$, so the two forms only parameterize the trade-off differently.)

Why 255? An 8-bit pixel has intensities in the interval $[0, 255]$, where $255 = 2^8 - 1$. The input image is "renormalized" so that $x \in [0, 1]$ instead. Then,

$$\lambda 255^2 D = \lambda 255^2 \text{MSE}(x, \hat{x}) = \lambda \text{MSE}(255 x, 255 \hat{x}) = \lambda D'$$

where $D'$ is a more typical "distortion" value that compression engineers may be used to.

It's not necessary to include $255$, though the $\lambda$ values given in the documentation assume the $255$ is there.
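
A quick numerical check of that identity (the tensors below are random and purely illustrative):

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)                           # image normalized to [0, 1]
x_hat = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)   # a fake "reconstruction"

d = F.mse_loss(x_hat, x)                    # D, computed on [0, 1] images
d_prime = F.mse_loss(255 * x_hat, 255 * x)  # D', computed on [0, 255] images

print(torch.allclose(255**2 * d, d_prime))  # True (up to floating point precision)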

Indraa145 commented 1 year ago

Ah, I see. Thank you for the explanation.