InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

feat(models): add Checkerboard and ELIC #243

Closed · YodaEmbedding closed 7 months ago

YodaEmbedding commented 1 year ago

Checkerboard and ELIC models.

Differences from original paper:


Checklist:


Compatibility check with Chandelier ELiC-ReImplemetation pretrained models

Download models from: https://github.com/VincentChandelier/ELiC-ReImplemetation#available-checkpoint

Save the following Python script as `rename_weights_elic.py`:

<details>
<summary>Click for python script</summary>

```python
import argparse

import torch

# Map keys from the ELiC-ReImplemetation state dict to CompressAI's
# latent-codec naming scheme.
FORWARD_MAPPING = {
    "g_a": "g_a",
    "g_s": "g_s",
    "h_a": "latent_codec.hyper.h_a",
    "h_s": "latent_codec.hyper.h_s",
    "entropy_bottleneck": "latent_codec.hyper.entropy_bottleneck",
    "cc_transforms.0": "latent_codec.y.channel_context.y1",
    "cc_transforms.1": "latent_codec.y.channel_context.y2",
    "cc_transforms.2": "latent_codec.y.channel_context.y3",
    "cc_transforms.3": "latent_codec.y.channel_context.y4",
    "context_prediction.0": "latent_codec.y.latent_codec.y0.context_prediction",
    "context_prediction.1": "latent_codec.y.latent_codec.y1.context_prediction",
    "context_prediction.2": "latent_codec.y.latent_codec.y2.context_prediction",
    "context_prediction.3": "latent_codec.y.latent_codec.y3.context_prediction",
    "context_prediction.4": "latent_codec.y.latent_codec.y4.context_prediction",
    "ParamAggregation.0": "latent_codec.y.latent_codec.y0.entropy_parameters",
    "ParamAggregation.1": "latent_codec.y.latent_codec.y1.entropy_parameters",
    "ParamAggregation.2": "latent_codec.y.latent_codec.y2.entropy_parameters",
    "ParamAggregation.3": "latent_codec.y.latent_codec.y3.entropy_parameters",
    "ParamAggregation.4": "latent_codec.y.latent_codec.y4.entropy_parameters",
}

# The single source gaussian_conditional is duplicated into each of the
# five per-group destination modules.
REVERSE_MAPPING = {
    "latent_codec.y.latent_codec.y0.y.gaussian_conditional": "gaussian_conditional",
    "latent_codec.y.latent_codec.y1.y.gaussian_conditional": "gaussian_conditional",
    "latent_codec.y.latent_codec.y2.y.gaussian_conditional": "gaussian_conditional",
    "latent_codec.y.latent_codec.y3.y.gaussian_conditional": "gaussian_conditional",
    "latent_codec.y.latent_codec.y4.y.gaussian_conditional": "gaussian_conditional",
}


def _rename_key(key):
    found = False
    for src_prefix, dst_prefix in FORWARD_MAPPING.items():
        if key.startswith(src_prefix):
            new_key = f"{dst_prefix}{key[len(src_prefix):]}"
            yield new_key
            found = True
    for dst_prefix, src_prefix in REVERSE_MAPPING.items():
        if key.startswith(src_prefix):
            new_key = f"{dst_prefix}{key[len(src_prefix):]}"
            yield new_key
            found = True
    if found:
        return
    raise RuntimeError(f"Unmapped key: {key}")


def rename_keys(state_dict):
    max_len = max(len(key) for key in state_dict.keys())
    new_state_dict = {}
    for key, value in state_dict.items():
        for new_key in _rename_key(key):
            print(f"{key:<{max_len}} -> {new_key}")
            new_state_dict[new_key] = value
    return new_state_dict


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input", type=str, required=True, help="Path to the weights file"
    )
    parser.add_argument(
        "--output", type=str, required=True, help="Path to the output file"
    )
    return parser


def main():
    parser = build_parser()
    args = parser.parse_args()
    state_dict = torch.load(args.input)
    state_dict = rename_keys(state_dict)
    torch.save(state_dict, args.output)


if __name__ == "__main__":
    main()
```

</details>

Then, run:

```bash
python rename_weights_elic.py --input=ELIC_0004_ft_3980_Plateau.pth.tar --output=ELIC_0004_ft_3980_Plateau_renamed.pth.tar
python rename_weights_elic.py --input=ELIC_0008_ft_3980_Plateau.pth.tar --output=ELIC_0008_ft_3980_Plateau_renamed.pth.tar
python rename_weights_elic.py --input=ELIC_0016_ft_3980_Plateau.pth.tar --output=ELIC_0016_ft_3980_Plateau_renamed.pth.tar
python rename_weights_elic.py --input=ELIC_0032_ft_3980_Plateau.pth.tar --output=ELIC_0032_ft_3980_Plateau_renamed.pth.tar
python rename_weights_elic.py --input=ELIC_0150_ft_3980_Plateau.pth.tar --output=ELIC_0150_ft_3980_Plateau_renamed.pth.tar
python rename_weights_elic.py --input=ELIC_0450_ft_3980_Plateau.pth.tar --output=ELIC_0450_ft_3980_Plateau_renamed.pth.tar
```

Then load the model checkpoint in CompressAI Trainer:

```bash
compressai-train ++model.name="elic2022-chandelier" ++hp.N=192 ++hp.M=320 ++hp.groups='[16,16,32,64,192]' ++criterion.lmbda=0.004 ++paths.model_checkpoint="ELIC_0004_ft_3980_Plateau_renamed.pth.tar"
```
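For a quick check without the Trainer, the renamed weights can also be loaded directly in Python. A minimal sketch, assuming the class added in this PR is exposed as `compressai.models.Elic2022Chandelier` and accepts the same `N`/`M`/`groups` hyperparameters as the `++hp` overrides above (both are assumptions, not confirmed API):

```python
# Minimal sketch: load the renamed Chandelier weights directly.
# Class name and constructor arguments are assumptions mirroring the
# ++model.name / ++hp overrides above.
import torch

from compressai.models import Elic2022Chandelier

model = Elic2022Chandelier(N=192, M=320, groups=[16, 16, 32, 64, 192])
state_dict = torch.load("ELIC_0004_ft_3980_Plateau_renamed.pth.tar", map_location="cpu")
model.load_state_dict(state_dict)
model.update(force=True)  # rebuild the entropy coder's CDF tables
model.eval()

# Round-trip a dummy image through the actual entropy coder.
x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    enc = model.compress(x)
    dec = model.decompress(enc["strings"], enc["shape"])
print(dec["x_hat"].shape)  # expect torch.Size([1, 3, 256, 256])
```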
fracape commented 1 year ago

Should this be trained with the base config (ReduceLROnPlateau)?

YodaEmbedding commented 1 year ago

Yes, the default is fine. Judging by this implementation's reported results, it is possible to get very close to the paper's results. Like the paper, they also train on 8000 images from ImageNet (rather than Vimeo90K), but hopefully that has minimal effect.
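For reference, "default" here means the usual ReduceLROnPlateau setup from CompressAI's training examples; a rough sketch, with illustrative hyperparameters:

```python
import torch
import torch.optim as optim

net = torch.nn.Conv2d(3, 3, 3)  # stand-in for the actual codec
optimizer = optim.Adam(net.parameters(), lr=1e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

# After each validation epoch, step on the RD loss; the learning rate
# is reduced automatically once the loss plateaus.
val_loss = 1.0  # stand-in for the epoch's validation loss
lr_scheduler.step(val_loss)
```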

fracape commented 11 months ago

Hi @YodaEmbedding, I was trying to figure out where the problem is, but my results using the above setup (Vimeo90K, default ReduceLROnPlateau) do not reach the expected performance:

Did you launch a run on your end, or would you have a clue regarding a wrong setup here?

[Figure: checkerboard vs cheng2020-anchor RD comparison]

YodaEmbedding commented 11 months ago

I think that may have been due to a bug in compress/decompress, which I've now fixed. The current implementation seems to work in some limited testing. I think forward was fine, so the model may have trained correctly; thus, you may be able to get away with just loading the weights to check, rather than training from scratch.
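One way to verify this without retraining is to compare `forward` against the `compress`/`decompress` round trip, which is exactly where the bug lived. A hypothetical helper (not part of CompressAI):

```python
import torch


def max_roundtrip_error(model, x):
    """Compare forward() reconstruction against the compress/decompress
    round trip; a large gap suggests a compress/decompress bug like the
    one described above. (Hypothetical helper, not part of CompressAI.)"""
    model.update(force=True)  # rebuild CDF tables before entropy coding
    model.eval()
    with torch.no_grad():
        x_hat_forward = model(x)["x_hat"]
        enc = model.compress(x)
        x_hat_coded = model.decompress(enc["strings"], enc["shape"])["x_hat"]
    return (x_hat_forward - x_hat_coded).abs().max().item()
```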


Summary of recent changes:


I plugged in Chandelier's pretrained weights (trained on ImageNet 8000) and then "finetuned" on Vimeo90K for a few epochs.

[Figure: Chandelier pretrained model finetuned on Vimeo90K for 1 epoch]
[Figure: Chandelier pretrained model finetuned on Vimeo90K for 4 epochs]

Not unexpectedly, "finetuning" on a different, larger dataset actually improves RD performance further.

Not shown: finetuning 0 epochs, since I'm a bit too lazy to rerun things.


Suggestions:

YodaEmbedding commented 11 months ago

(Ignore.)

lin-toto commented 11 months ago

Regarding the `y - means` trick: it wouldn't work with GMM. We can't really pre-compute the distributions with GMM, because they depend on too many parameters.
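(For context, the trick in question is single-Gaussian mean-removal quantization; a rough sketch, with illustrative naming:)

```python
import torch


def quantize_with_means(y: torch.Tensor, means: torch.Tensor) -> torch.Tensor:
    """Single-Gaussian "y - means" trick: code round(y - means), which
    follows a zero-mean Gaussian whose CDF depends only on the predicted
    scale, so per-scale CDF tables can be pre-computed. A K-component GMM
    ties each element's distribution to K means, scales, and weights, so
    no single-parameter table lookup is possible.
    """
    symbols = torch.round(y - means)  # what the entropy coder sees
    y_hat = symbols + means           # reconstruction used downstream
    return y_hat
```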

lin-toto commented 11 months ago

Also, I'm not sure about ELIC, but FYI: in my experiments, Checkerboard + cheng2020 results differ by quite a lot between K=1 and K=3 (GMM components).
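(To make the `K` dependence concrete, here is an illustrative K-component GMM likelihood for a quantized latent; a sketch, not CompressAI's actual implementation:)

```python
import torch


def gmm_likelihood(y_hat, means, scales, weights):
    """Likelihood of quantized y_hat under a K-component GMM.

    means, scales, weights have shape (K, ...) broadcastable to y_hat,
    with weights summing to 1 over K. With K=1 this reduces to the usual
    single Gaussian conditional; with K=3 each element carries 9
    distribution parameters.
    """
    dist = torch.distributions.Normal(means, scales)
    upper = dist.cdf(y_hat + 0.5)  # integrate the density over the
    lower = dist.cdf(y_hat - 0.5)  # quantization bin [y-0.5, y+0.5]
    return (weights * (upper - lower)).sum(dim=0)
```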