danishnazir opened this issue 11 months ago
The likelihood map $l = -\log_2 p_{\hat{y}}(\hat{y})$, of dimensions $M_y \times \frac{H}{2^4} \times \frac{W}{2^4}$, can be calculated exactly, so you can plot the bit cost of each latent element. Since it is 3D, you can flatten it to 2D by summing along the channel dimension.
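A minimal sketch of this computation, assuming the likelihoods for a single image have already been converted to a NumPy array of shape $(M_y, H/16, W/16)$. The random values and the helper names `latent_bit_map` and `flatten_channels` are hypothetical stand-ins, not part of CompressAI:

```python
import numpy as np


def latent_bit_map(likelihoods: np.ndarray) -> np.ndarray:
    """Per-element bit cost l = -log2 p(y_hat) for a (C, h, w) likelihood array."""
    # Clip to avoid log2(0); trained entropy models typically lower-bound likelihoods anyway.
    return -np.log2(np.clip(likelihoods, 1e-9, 1.0))


def flatten_channels(bits: np.ndarray) -> np.ndarray:
    """Collapse the 3D (C, h, w) bit map to 2D (h, w) by summing over channels."""
    return bits.sum(axis=0)


# Hypothetical example: M_y = 192 channels for a 512x768 image -> 32x48 latent grid.
rng = np.random.default_rng(0)
H, W, M_y = 512, 768, 192
likelihoods = rng.uniform(0.05, 1.0, size=(M_y, H // 16, W // 16))

bits = latent_bit_map(likelihoods)   # shape (192, 32, 48)
bit_map_2d = flatten_channels(bits)  # shape (32, 48), bits per latent position
total_bits = bit_map_2d.sum()
bpp = total_bits / (H * W)           # average bits per image pixel
```

The 2D map sums to the same total as the 3D map, so dividing by $H \times W$ gives the usual bits-per-pixel figure.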
In bmshj2018-factorized, each latent pixel $\hat{y}_i$ affects the reconstructed pixels of $\hat{x}$ within a 9x9 window. So one way to map $l$ back to the image domain is simply to upsample and interpolate.
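A minimal sketch of the upsampling step, assuming a 2D latent bit map and an integer downsampling factor of 16. It swaps in nearest-neighbour (block) replication rather than bilinear interpolation: dividing each latent cell's bits by $16^2$ and copying them into its 16x16 block preserves the total bit count exactly, whereas bilinear interpolation gives a smoother map that only approximately preserves the total. The helper name `bits_to_image_domain` is hypothetical, and the block assignment ignores the overlap between neighbouring latents' footprints.

```python
import numpy as np


def bits_to_image_domain(bit_map: np.ndarray, factor: int = 16) -> np.ndarray:
    """Spread an (h, w) latent bit map over an (h*factor, w*factor) pixel grid.

    Each latent cell's bits are divided evenly among the factor x factor image
    pixels it is aligned with, so the image-domain map sums to the same total.
    """
    per_pixel = bit_map / factor**2
    # np.kron with a block of ones replicates each cell into a factor x factor block.
    return np.kron(per_pixel, np.ones((factor, factor)))


# Hypothetical 2x3 latent bit map (bits per latent position, summed over channels).
bit_map = np.array([[32.0, 64.0, 0.0],
                    [128.0, 16.0, 8.0]])
pixel_map = bits_to_image_domain(bit_map, factor=16)

print(pixel_map.shape)                 # (32, 48)
print(pixel_map.sum(), bit_map.sum())  # 248.0 248.0
```

The mean of `pixel_map` is then directly the average bits per pixel of the image.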
I guess a more advanced method for estimating rate costs of encoding each pixel would involve training some sort of rate estimation model or maybe some Grad-CAM-like technique.
Thank you for your answer. I tried to do it the following way; could you please point out any mistakes?
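```python
import numpy as np

model.eval()  # model = bmshj2018_factorized(...); its entropy_bottleneck models y directly
model.update()
y = model.g_a(x)  # x: (1, 3, H, W) input image tensor
y_hat, likelihood_map = model.entropy_bottleneck(y)
likelihood_map = likelihood_map[0].detach().cpu().numpy()  # shape (C, H/16, W/16)
pixel_bitrates = (-np.log2(likelihood_map)).sum(axis=0)  # bits, channel-wise sum (NumPy uses axis=, not dim=)
```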
Now, should I just upsample `pixel_bitrates` to the size of `x` using an interpolation technique such as bilinear?
```python
import matplotlib.pyplot as plt
import torch.nn.functional as F
from compressai.zoo import bmshj2018_factorized
from PIL import Image
from torchvision import transforms

device = "cuda"

# Load the test image once; it is the same for every quality level.
img = Image.open("/data/datasets/kodak/test/kodim01.png").convert("RGB")
x = transforms.ToTensor()(img).unsqueeze(0).to(device)

for quality in [1, 2, 3, 4, 5, 6, 7, 8]:
    model = bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device)

    # Analysis transform, then the factorized entropy model gives the likelihood of each y_hat element.
    y = model.g_a(x)
    y_hat, y_likelihoods = model.entropy_bottleneck(y)
    y_nll = -y_likelihoods.log2()  # bit cost of each latent element

    # Sum over channels, divide by the number of image pixels each latent position
    # covers (scale_factor**2), then upsample to image resolution.
    scale_factor = x.shape[-1] / y.shape[-1]
    x_nll = F.interpolate(
        y_nll.sum(dim=1, keepdim=True) / scale_factor**2,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )

    fig, ax = plt.subplots(tight_layout=True)
    im = ax.imshow(x_nll[0, 0].detach().cpu().numpy(), vmin=0)
    fig.colorbar(im)
    ax.set(title=f"bmshj2018-factorized q={quality}")
    fig.savefig(f"x_nll_q={quality}.png")
    plt.close(fig)
```
Just for fun, to animate:
```bash
ffmpeg -framerate 1 -pattern_type glob -i 'x_nll_q*.png' -f apng -plays 0 x_nll_all.png
```
[Figure: Input | Negative log likelihoods (bits)]
Low-frequency regions consume the least rate.
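The same idea can be extended to models with a hyperprior, where both $\hat{y}$ and $\hat{z}$ carry a rate. A minimal sketch, assuming the per-latent likelihoods have been pulled out as NumPy arrays of shape (C, h, w) for a single image; the function name `per_pixel_rate_maps` is hypothetical. Each latent's downsampling factor is inferred from its spatial size relative to the image (16 for y and 64 for z in the bmshj2018 models), and each position's bits are spread evenly over its block of pixels so every map keeps its latent's total bit count.

```python
import numpy as np


def per_pixel_rate_maps(likelihoods: dict, image_hw: tuple) -> dict:
    """Map each latent's bit cost to image resolution and add a combined map.

    likelihoods: {"y": (C_y, h_y, w_y) array, "z": (C_z, h_z, w_z) array, ...}
    image_hw:    (H, W) of the input image; H and W must be exact multiples of
                 each latent's spatial size (e.g. pad the image first).
    """
    H, W = image_hw
    maps = {}
    for name, lik in likelihoods.items():
        # Bits per latent position, summed over channels.
        bits = (-np.log2(np.clip(lik, 1e-9, 1.0))).sum(axis=0)
        h, w = bits.shape
        fy, fx = H // h, W // w
        if fy * h != H or fx * w != W:
            raise ValueError(f"{name}: image size {image_hw} is not a multiple of {bits.shape}")
        # Spread each position's bits evenly over its fy x fx block of image pixels.
        maps[name] = np.kron(bits / (fy * fx), np.ones((fy, fx)))
    maps["total"] = sum(maps.values())
    return maps


# Hypothetical shapes for a 512x768 image: y is downsampled by 16, z by 64.
rng = np.random.default_rng(0)
lik = {
    "y": rng.uniform(0.1, 1.0, size=(192, 32, 48)),
    "z": rng.uniform(0.1, 1.0, size=(128, 8, 12)),
}
maps = per_pixel_rate_maps(lik, (512, 768))
bpp = maps["total"].mean()  # average bits per pixel over y and z
```

With a pretrained CompressAI hyperprior model (e.g. `bmshj2018_hyperprior` from `compressai.zoo`), the input dictionary could come from the forward pass: `out = model(x)`, then `{k: v[0].detach().cpu().numpy() for k, v in out["likelihoods"].items()}`, since `out["likelihoods"]` holds the `"y"` and `"z"` entries.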
Hi,
Is there a way to compute per-pixel bitrate information, in RGB image space, for the latent variable y or z during evaluation?
Regards