InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

size mismatch when loading checkpoint (ssf2020) #113

Closed amarzullo24 closed 2 years ago

amarzullo24 commented 2 years ago

Bug

I am trying to load a checkpoint after training the ssf2020 model. However, when I load the checkpoint, a size mismatch error occurs.

To Reproduce

Run the following script (provided you already have a "checkpoint.pth.tar")

import torch
from compressai.zoo import models_video

net = models_video[args.model](quality=3)  # args.model -> default: ssf2020
net = net.to(device)  # args and device come from the surrounding training script

torch.save(net.state_dict(), "checkpoint.pth.tar")
net.load_state_dict(torch.load("checkpoint.pth.tar"))

You should get the following error:

Traceback (most recent call last):
  File "/home/user/videocompression/main.py", line 480, in <module>
    main(sys.argv[1:])
  File "/home/user/videocompression/main.py", line 427, in main
    net.load_state_dict(torch.load("checkpoint.pth.tar"))
  File "/home/user/videocompression/compressai/compressai/models/video/google.py", line 484, in load_state_dict
    super().load_state_dict(state_dict)
  File "/home/user/.conda/envs/compression/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ScaleSpaceFlow:
        size mismatch for img_hyperprior.entropy_bottleneck._offset: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).
        size mismatch for img_hyperprior.entropy_bottleneck._quantized_cdf: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192, 23]).
        size mismatch for img_hyperprior.entropy_bottleneck._cdf_length: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).
        size mismatch for res_hyperprior.entropy_bottleneck._offset: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).
        size mismatch for res_hyperprior.entropy_bottleneck._quantized_cdf: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192, 23]).
        size mismatch for res_hyperprior.entropy_bottleneck._cdf_length: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).
        size mismatch for motion_hyperprior.entropy_bottleneck._offset: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).
        size mismatch for motion_hyperprior.entropy_bottleneck._quantized_cdf: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192, 23]).
        size mismatch for motion_hyperprior.entropy_bottleneck._cdf_length: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([192]).

Expected behavior

The model state loads properly.

Environment

Please copy and paste the output from python3 -m torch.utils.collect_env

Collecting environment information...
PyTorch version: 1.8.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB
  MIG 2g.10gb     Device  0:
GPU 2: NVIDIA A100-PCIE-40GB

Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.4

HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-msssim==0.2.1
[pip3] torch==1.8.0+cu111
[pip3] torchaudio==0.8.0
[pip3] torchvision==0.9.0+cu111
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.2.89              hfd86e86_1
[conda] libblas                   3.9.0            12_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            12_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            12_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            12_linux64_mkl    conda-forge
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py39h7f8727e_0
[conda] mkl_fft                   1.3.1            py39hd3c417c_0
[conda] mkl_random                1.2.2            py39h51133e4_0
[conda] numpy                     1.21.2           py39h20f2e39_0
[conda] numpy-base                1.21.2           py39h79a1101_0
[conda] pytorch-msssim            0.2.1                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.8.0+cu111              pypi_0    pypi
[conda] torchaudio                0.8.0                    pypi_0    pypi
[conda] torchvision               0.9.0+cu111              pypi_0    pypi

Additional context

I have noticed that if I print the actual values of the parameters listed in the stack trace, the corresponding tensors in the state_dict are actually empty:

net = models_video[args.model](quality=3)
print(net.state_dict()["img_hyperprior.entropy_bottleneck._offset"])  # output: tensor([], device='cuda:0', dtype=torch.int32)

Update

As expected, by commenting out lines 439-482 of compressai/models/video/google.py I can get the loading to work. However, I guess this is not a proper solution.

YodaEmbedding commented 2 years ago

There are two possible methods to handle this.

  1. Run update_model to modify the checkpoint:

    python -m compressai.utils.update_model checkpoint.pth.tar

    This also freezes the checkpoint, removes some state (e.g. optimizer), and adds a hash to the filename. If that is not desired, the alternative is...

  2. After loading the model, call net.update(force=True):

    net = models_video[args.model](quality=3)
    net.update(force=True)  # (re)builds the CDF tables needed for entropy coding

update_model is necessary to update the parameters needed for entropy coding. These parameters are not used during training, so they can be added after training. More info here: https://github.com/InterDigitalInc/CompressAI/issues/5#issuecomment-724881519
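
To make the effect concrete, here is a hedged sketch (the exact CDF table width depends on the learned distributions, so the shapes are illustrative): the bottleneck buffers start out empty and are filled in by update(force=True).

    from compressai.zoo import models_video

    net = models_video["ssf2020"](quality=3)
    eb = net.img_hyperprior.entropy_bottleneck
    print(eb._quantized_cdf.shape)   # torch.Size([0]) right after construction
    net.update(force=True)           # builds the CDF tables used by the entropy coder
    print(eb._quantized_cdf.shape)   # now non-empty, e.g. torch.Size([192, 23])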


Personally, I've created a wrapper that injects net.update(force=True) into compressai.utils.eval_model.

# Filename: personal_code/utils/compressai/update_and_eval_model.py 

import compressai.utils.eval_model.__main__ as eval_model_main
import torch
import torch.nn as nn
from compressai.zoo import load_state_dict
from compressai.zoo.image import model_architectures as architectures

def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    ckpt = torch.load(checkpoint_path)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)
    model = architectures[arch].from_state_dict(state_dict).eval()
    model.update(force=True)
    return model

def main(argv):
    eval_model_main.__dict__["load_checkpoint"] = load_checkpoint
    eval_model_main.main(argv)
# Filename: personal_code/utils/compressai/__main__.py

import importlib
import sys

if __name__ == "__main__":
    _, util_name, *argv = sys.argv
    if util_name == "update_and_eval_model":
        from . import update_and_eval_model

        main = update_and_eval_model.main
    else:
        module = importlib.import_module(
            f"compressai.utils.{util_name}.__main__"
        )
        main = module.main

    main(argv)

Usage:

python -m personal_code.utils.compressai update_and_eval_model checkpoint $DATASET_PATH --verbose -a=bmshj2018-factorized -p checkpoint.pth.tar > results.json
amarzullo24 commented 2 years ago

Hi, thanks for the detailed reply! I have tried to adapt both solutions but I still get the same error.

Method 1

By using python -m compressai.utils.update_model checkpoint.pth.tar I get an error:

Traceback (most recent call last):
  File "/home/user/.conda/envs/compression/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/.conda/envs/compression/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/videocompression/compressai/compressai/utils/update_model/__main__.py", line 165, in <module>
    main(sys.argv[1:])
  File "/home/user/videocompression/compressai/compressai/utils/update_model/__main__.py", line 136, in main
    net = model_cls.from_state_dict(state_dict)
  File "/home/user/videocompression/compressai/compressai/models/google.py", line 285, in from_state_dict
    N = state_dict["g_a.0.weight"].size(0)
KeyError: 'g_a.0.weight'

Could it be related to the modifications I made in #114?

Method 2

I have written the following method:

import torch

def load_checkpoint(checkpoint_path):
    from compressai.zoo import load_state_dict
    from compressai.zoo import models_video

    ckpt = torch.load(checkpoint_path)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)

    model = models_video['ssf2020'](quality=3).from_state_dict(state_dict).eval()
    model.update(force=True)
    return model

But I still get the same size mismatch error.

I am trying to dig into this by myself, but I think I am missing something... maybe you can spot the problem more easily.

Thanks for your help!

chyomin06 commented 2 years ago

Thanks for the report. We've confirmed that this is a bug, and we will fix it soon. (Note that the current release, v1.2.0b2, is a beta version, so new features and weights for the video models may be fragile.)

In the meantime, you can modify load_state_dict in google.py under the video folder: in your local copy, keep only the code related to the entropy_bottleneck buffer updates, with super().load_state_dict(state_dict) at the end.
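
A minimal sketch of that change, assuming update_registered_buffers is importable from compressai.models.utils as it is for the image models (the hyperprior names follow the traceback above; this is not the exact upstream code, and the method must live inside the ScaleSpaceFlow class for super() to resolve):

    # In compressai/models/video/google.py
    from compressai.models.utils import update_registered_buffers

    # Replace ScaleSpaceFlow.load_state_dict with something like:
    def load_state_dict(self, state_dict):
        # Resize the entropy_bottleneck CDF buffers to match the checkpoint,
        # then defer to the default nn.Module loading.
        for name in ("img_hyperprior", "res_hyperprior", "motion_hyperprior"):
            update_registered_buffers(
                getattr(self, name).entropy_bottleneck,
                f"{name}.entropy_bottleneck",
                ["_quantized_cdf", "_offset", "_cdf_length"],
                state_dict,
            )
        super().load_state_dict(state_dict)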

Thank you.

amarzullo24 commented 2 years ago

Thank you for the answer. I temporarily fixed it by putting the following lines in the training script (e.g. train_video.py), right after the model initialization:

    net = models_video[args.model](quality=3)
    net = net.to(device)

    # needed for correct checkpoint saving
    net.img_hyperprior.entropy_bottleneck.update()
    net.res_hyperprior.entropy_bottleneck.update()
    net.motion_hyperprior.entropy_bottleneck.update()

This allows loading the checkpoint without modifying load_state_dict. Hope it helps!
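
With the bottleneck buffers populated before saving, the save/load round trip from the "To Reproduce" snippet should then go through (a quick sketch under the same assumptions as above):

    # after the three entropy_bottleneck.update() calls above
    torch.save(net.state_dict(), "checkpoint.pth.tar")
    net.load_state_dict(torch.load("checkpoint.pth.tar"))  # no size mismatch expected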

chyomin06 commented 2 years ago

Hi @emmeduz,

One of the recent commits (https://github.com/InterDigitalInc/CompressAI/commit/84bf9203a708dcd09fd1d2fb7dea7c65a48ff94c) hopefully fixes your issue, so please check out the up-to-date master branch.

Thanks for all the information above. It was really helpful in resolving the issue.