InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Issue with Using CheckerboardMaskedConv2d from layer.py #291

Closed formioq closed 3 months ago

formioq commented 3 months ago

While training with mbt2018 as the base model, I replaced the original MaskedConv2d context model with CheckerboardMaskedConv2d (mask_type=B). During testing I noticed that while the PSNR increased to 33 dB (at epoch 100), the bpp dropped to 0.13 (in fact, the bpp was already around 0.2 from the first epoch, with lambda=0.01). Such results are clearly abnormal. I would like to know whether there is a problem with doing this replacement directly. Please tell me where the problem lies, or whether there is a solution.
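
Roughly, the swap looks like this (a minimal sketch; the subclass name and hyperparameters are illustrative, not my exact code):

from compressai.layers import CheckerboardMaskedConv2d
from compressai.models import JointAutoregressiveHierarchicalPriors  # mbt2018

class Mbt2018Checkerboard(JointAutoregressiveHierarchicalPriors):
    def __init__(self, N=192, M=192, **kwargs):
        super().__init__(N=N, M=M, **kwargs)
        # Swap the raster-scan MaskedConv2d context model for the checkerboard one.
        self.context_prediction = CheckerboardMaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1, mask_type="B"
        )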

YodaEmbedding commented 3 months ago

This usually occurs when some of the likelihood terms are missing during the bpp_loss computation. Perhaps the likelihood dict is not in the expected format:

assert likelihoods == {
    "y0": torch.Tensor(...),
    "y1": torch.Tensor(...),
    "z": torch.Tensor(...),
    ...
}
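
A quick way to check this (a hypothetical debugging snippet) is to print what forward() actually returns before it goes into the loss:

out = model(x)
for name, lk in out["likelihoods"].items():
    print(name, type(lk), getattr(lk, "shape", None))
# If one of the y halves is missing here, or a nested dict/list is silently
# skipped by the loss, bpp_loss comes out unrealistically small.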

Similarly, during model eval, the bpp should be measured from the lengths of all the bytestrings. If the evaluated bpp is also unrealistically small, it might be because the code is measuring the number of keys or the list length instead.
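
For example, assuming the usual convention that out_enc["strings"] is a list of lists of bytes (one inner list per latent), the difference looks like this:

num_pixels = x.size(0) * x.size(2) * x.size(3)  # x is the B x C x H x W input

# Wrong: counts how many bytestrings there are, not their sizes.
bpp_wrong = sum(len(strings) for strings in out_enc["strings"]) * 8.0 / num_pixels

# Right: sums the byte length of every bytestring.
bpp = sum(len(s) for strings in out_enc["strings"] for s in strings) * 8.0 / num_pixels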


Another alternative is to recursively flatten out the likelihoods/bytestrings:

import torch
from torch import nn


def flatten_values(x, value_type=object):
    """Recursively yield all values of type `value_type` from nested containers."""
    if isinstance(x, (list, tuple, set)):
        for v in x:
            yield from flatten_values(v, value_type)
    elif isinstance(x, dict):
        for v in x.values():
            yield from flatten_values(v, value_type)
    elif isinstance(x, value_type):
        yield x
    else:
        raise ValueError(f"Unexpected type {type(x)}")

class RateDistortionLoss(nn.Module):
    def forward(self, output, target):
        N, _, H, W = target.size()
        num_pixels = N * H * W
        out = {}
        # Sum over *all* likelihood tensors, however deeply they are nested.
        out["bpp_loss"] = sum(
            likelihoods.log2().sum() / -num_pixels
            for likelihoods in flatten_values(output["likelihoods"], torch.Tensor)
        )
        ...

def inference(model, x, vbr_stage=None, vbr_scale=None):
    ...
    # Count the bytes of every bytestring, however deeply the strings are nested.
    bpp = sum(len(s) for s in flatten_values(out_enc["strings"], bytes)) * 8.0 / num_pixels
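
As a quick sanity check (illustrative shapes only), the flattening picks up nested likelihood terms that a flat loop would miss:

likelihoods = {
    "y": {"anchor": torch.rand(1, 192, 8, 8), "non_anchor": torch.rand(1, 192, 8, 8)},
    "z": torch.rand(1, 192, 2, 2),
}
assert len(list(flatten_values(likelihoods, torch.Tensor))) == 3  # all terms counted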

I might look into opening a PR for this so that this bug becomes less likely.