InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Compute pixel-wise bitrate allocation in latent space #265

Open danishnazir opened 11 months ago

danishnazir commented 11 months ago

Hi,

Is there a way to compute per-pixel bitrate information in RGB (image) space from the latent variables y or z during evaluation?

Regards

YodaEmbedding commented 11 months ago

The negative log-likelihood map $l = -\log_2 p_{\hat{y}}(\hat{y})$, of dimensions $M_y \times \frac{H}{2^4} \times \frac{W}{2^4}$, can be calculated exactly, so you can plot the bit cost of each latent element. Since it's 3D, you can flatten it to 2D by summing along the channel dimension.
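As a rough illustration of the two steps above (bit cost per latent element, then a channel-wise sum), here is a minimal NumPy sketch. The `likelihoods` array is made-up toy data standing in for the tensor an entropy model would return, and its shape (2 channels on a 2x3 grid) is an assumption chosen only to keep the numbers readable.

```python
import numpy as np

# Hypothetical likelihoods p(y_hat) for a latent of shape (channels, height, width) = (2, 2, 3).
# These values are invented for illustration; a real model would produce them.
likelihoods = np.array([
    [[0.5, 0.25, 0.5], [0.125, 0.5, 0.25]],
    [[0.5, 0.5, 0.25], [0.5, 0.25, 0.25]],
])

# Bit cost of each latent element: l = -log2 p(y_hat).
nll_bits = -np.log2(likelihoods)

# Sum over the channel axis to get one bit cost per latent spatial position.
bits_per_position = nll_bits.sum(axis=0)

# Total estimated bits for the whole latent.
total_bits = nll_bits.sum()

print(bits_per_position)
print(total_bits)
```

The 2D `bits_per_position` map is what would then be upsampled to image resolution.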

In bmshj2018-factorized, each $\hat{y}_i$ latent pixel affects reconstructed $\hat{x}$ pixels within a window of 9x9. So, one way to map the $l$ back to the image domain is just to upsample and interpolate.

I guess a more advanced method for estimating rate costs of encoding each pixel would involve training some sort of rate estimation model or maybe some Grad-CAM-like technique.

danishnazir commented 11 months ago

Thank you for your answer. I tried to do it in the following way; could you please point out any mistakes?

model.eval() #model = hyperprior
model.update()

y = model.g_a(x)
y_hat, likelihood_map = model.entropy_bottleneck(y)
likelihood_map = likelihood_map [0].detach().cpu().numpy()
pixel_bitrates = likelihood_map.sum(dim=1) #channel-wise sum

Now, should I just upsample pixel_bitrates to the resolution of x using an interpolation technique such as bilinear?

YodaEmbedding commented 11 months ago

bmshj2018-factorized NLLs (negative log likelihoods):

import matplotlib.pyplot as plt
import torch.nn.functional as F
from compressai.zoo import bmshj2018_factorized
from PIL import Image
from torchvision import transforms

device = "cuda"

for quality in [1, 2, 3, 4, 5, 6, 7, 8]:
    model = bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device)
    img = Image.open("/data/datasets/kodak/test/kodim01.png").convert("RGB")
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
    _, _, H, W = x.shape

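    # Analysis transform, then the factorized entropy model's likelihoods.
    y = model.g_a(x)
    y_hat, y_likelihoods = model.entropy_bottleneck(y)
    # Negative log-likelihood in bits for each latent element.
    y_nll = -y_likelihoods.log2()
    # Spatial downsampling factor between the image and the latent.
    scale_factor = x.shape[-1] / y.shape[-1]
    # Sum over channels, spread each latent position's bits over the image
    # pixels it covers, then upsample to image resolution.
    x_nll = F.interpolate(
        y_nll.sum(dim=1, keepdim=True) / scale_factor**2,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )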

    fig, ax = plt.subplots(tight_layout=True)
    im = ax.imshow(x_nll[0, 0].detach().cpu().numpy(), vmin=0)
    fig.colorbar(im)
    ax.set(title=f"bmshj2018-factorized  q={quality}")
    fig.savefig(f"x_nll_q={quality}.png")
    plt.close(fig)
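
The division by scale_factor**2 before upsampling is what keeps the map in units of bits per image pixel: each latent position's cost is spread over the block of pixels it covers, so the total number of bits is unchanged. A small NumPy sketch of that bookkeeping follows. It uses nearest-neighbour upsampling via `np.repeat` as a simpler stand-in for bilinear interpolation, and the toy `bits_latent` values and the factor of 16 are assumptions for illustration.

```python
import numpy as np

# Hypothetical per-position bit costs on a 2x3 latent grid (already summed over channels).
bits_latent = np.array([[2.0, 3.0, 3.0], [4.0, 3.0, 4.0]])
scale = 16  # assumed spatial downsampling factor between image and latent

# Spread each latent position's bits evenly over its scale x scale block of pixels.
# Nearest-neighbour upsampling is used here as a stand-in for bilinear interpolation.
bits_image = np.repeat(np.repeat(bits_latent / scale**2, scale, axis=0), scale, axis=1)

total_latent = bits_latent.sum()
total_image = bits_image.sum()

# Average bits per pixel over the upsampled map.
bpp = total_image / bits_image.size

print(bits_image.shape, total_latent, total_image, bpp)
```

With bilinear interpolation the per-pixel values are smoothed between neighbours, so the total should be preserved at least approximately rather than by construction.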

Just for fun, to animate:

ffmpeg -framerate 1 -pattern_type glob -i 'x_nll_q*.png' -f apng -plays 0 x_nll_all.png
[Images: input image; negative log likelihoods (bits), x_nll_all]

Low-frequency regions consume the least rate.