There are two possible methods to handle this.
Run update_model to modify the checkpoint:
python -m compressai.utils.update_model checkpoint.pth.tar
This also freezes the checkpoint, removes some state (e.g. the optimizer), and adds a hash to the filename. If that is not desired, the alternative is...
After loading the model, call net.update(force=True):
net = models_video[args.model](quality=3)
net.update(force=True)
update_model is necessary to update the parameters needed for entropy coding. These parameters are not used during training, so they can be added after training. More info here: https://github.com/InterDigitalInc/CompressAI/issues/5#issuecomment-724881519
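To make that concrete, here is a small sketch: the entropy-coding tables live in buffers that stay empty all through training and only get filled by update(). The submodule path and the _quantized_cdf buffer name below follow the usage elsewhere in this thread and CompressAI's entropy models, so treat them as assumptions to verify against your installed version.
from compressai.zoo import models_video

net = models_video["ssf2020"](quality=3)
eb = net.img_hyperprior.entropy_bottleneck
print(eb._quantized_cdf.numel())  # 0: the CDF buffer is empty after construction/training
net.update(force=True)            # builds the quantized CDF tables used by the range coder
print(eb._quantized_cdf.numel())  # now non-zero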
Personally, I've created a wrapper that injects net.update(force=True) into compressai.utils.eval_model.
# Filename: personal_code/utils/compressai/update_and_eval_model.py
import compressai.utils.eval_model.__main__ as eval_model_main
import torch
import torch.nn as nn
from compressai.zoo import load_state_dict
from compressai.zoo.image import model_architectures as architectures


def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    ckpt = torch.load(checkpoint_path)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)
    model = architectures[arch].from_state_dict(state_dict).eval()
    # Build the entropy-coding tables that are not created during training.
    model.update(force=True)
    return model


def main(argv):
    # Replace eval_model's own load_checkpoint with the one above, then delegate.
    eval_model_main.__dict__["load_checkpoint"] = load_checkpoint
    eval_model_main.main(argv)
# Filename: personal_code/utils/compressai/__main__.py
import importlib
import sys

if __name__ == "__main__":
    _, util_name, *argv = sys.argv
    if util_name == "update_and_eval_model":
        from . import update_and_eval_model
        main = update_and_eval_model.main
    else:
        # Fall back to the stock CompressAI utilities (update_model, eval_model, ...).
        module = importlib.import_module(f"compressai.utils.{util_name}.__main__")
        main = module.main
    main(argv)
Usage:
python -m personal_code.utils.compressai update_and_eval_model checkpoint $DATASET_PATH --verbose -a=bmshj2018-factorized -p checkpoint.pth.tar > results.json
Hi, thanks for the detailed reply! I have tried to adapt both solutions but I still get the same error.
By using python -m compressai.utils.update_model checkpoint.pth.tar I get an error:
Traceback (most recent call last):
  File "/home/user/.conda/envs/compression/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/user/.conda/envs/compression/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/videocompression/compressai/compressai/utils/update_model/__main__.py", line 165, in <module>
    main(sys.argv[1:])
  File "/home/user/videocompression/compressai/compressai/utils/update_model/__main__.py", line 136, in main
    net = model_cls.from_state_dict(state_dict)
  File "/home/user/videocompression/compressai/compressai/models/google.py", line 285, in from_state_dict
    N = state_dict["g_a.0.weight"].size(0)
KeyError: 'g_a.0.weight'
Could it be related to the modification I have done in #114?
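For what it's worth, listing the checkpoint's top-level key prefixes shows why from_state_dict fails here (a sketch; the checkpoint path is an assumption):
import torch

ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
# from_state_dict in compressai/models/google.py looks for image-model keys such as
# "g_a.0.weight"; an ssf2020 checkpoint has no "g_a.*" entries, hence the KeyError.
print(sorted({k.split(".")[0] for k in state_dict}))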
I have written the following method:
def load_checkpoint(checkpoint_path):
    from compressai.zoo import load_state_dict
    from compressai.zoo import models_video

    ckpt = torch.load(checkpoint_path)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)
    model = models_video['ssf2020'](quality=3).from_state_dict(state_dict).eval()
    model.update(force=True)
    return model
But I still get the same size-mismatch error.
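To see exactly which entries disagree, one option is a shape diff between a freshly built model and the checkpoint (a sketch, assuming the checkpoint comes from a quality=3 ssf2020 model):
import torch
from compressai.zoo import models_video

ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
net = models_video["ssf2020"](quality=3)
# Print every entry whose shape differs between the model and the checkpoint.
for name, tensor in net.state_dict().items():
    if name in state_dict and tensor.shape != state_dict[name].shape:
        print(name, tuple(tensor.shape), "!=", tuple(state_dict[name].shape))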
I am trying to dig into this myself, but I think I am missing something... maybe you could spot the problem more easily.
Thanks for your help
Thanks for the report. We've confirmed that this is a bug, so we will fix it soon. (Note that the current release v1.2.0b2 is the beta version, so new features and weights for the video model could be fragile.)
In the meantime, you can modify "load_state_dict" in google.py under the video folder.
In your local copy, simply keep only the code related to the entropy_bottleneck buffer-register updates and the final "super().load_state_dict(state_dict)" call.
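For reference, a standalone helper with the same effect might look roughly like the following. This is only a sketch of the hint above, not the actual google.py code; the update_registered_buffers helper and the buffer names are borrowed from CompressAI's image models and should be double-checked against your local version.
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck
from compressai.models.utils import update_registered_buffers


def load_trimmed_state_dict(net: nn.Module, state_dict):
    # Resize the entropy-bottleneck CDF buffers registered on the model so their
    # shapes match the tensors stored in the checkpoint...
    for name, module in net.named_modules():
        if isinstance(module, EntropyBottleneck):
            update_registered_buffers(
                module,
                name,
                ["_quantized_cdf", "_offset", "_cdf_length"],
                state_dict,
            )
    # ...then do a plain load, i.e. the super().load_state_dict(state_dict) part.
    nn.Module.load_state_dict(net, state_dict)
    return net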
Thank you.
Thank you for the answer. I temporarily fixed it by putting the following lines in the training script (e.g. train_video.py), right after the model initialization:
net = models_video[args.model](quality=3)
net = net.to(device)
# needed for correct checkpoint saving
net.img_hyperprior.entropy_bottleneck.update()
net.res_hyperprior.entropy_bottleneck.update()
net.motion_hyperprior.entropy_bottleneck.update()
This allows the checkpoint to be loaded without modifying load_state_dict.
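As a quick sanity check (a sketch; the _quantized_cdf buffer name follows CompressAI's entropy models), the saved state dict now contains non-empty CDF buffers, which is what makes the later load succeed:
import torch

ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
# With the update() calls above, these buffers are saved with their final sizes
# instead of being empty.
for key in state_dict:
    if key.endswith("_quantized_cdf"):
        print(key, tuple(state_dict[key].shape))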
Hope it helps
Hi @emmeduz,
One of the recent commits (https://github.com/InterDigitalInc/CompressAI/commit/84bf9203a708dcd09fd1d2fb7dea7c65a48ff94c) hopefully fixes your issue. So, please check out the up-to-date master branch.
Thanks for all the information above. It was really helpful to resolve the issue.
Bug
I am trying to load a checkpoint after training the ssf2020 model. However, when I load the checkpoint, a size mismatch error occurs.
To Reproduce
Run the following script (provided you already have a "checkpoint.pth.tar")
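A minimal version of such a script, along the lines of the load_checkpoint method quoted earlier in the thread (a sketch, not the exact original):
import torch
from compressai.zoo import models_video

net = models_video["ssf2020"](quality=3)
ckpt = torch.load("checkpoint.pth.tar", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
net.load_state_dict(state_dict)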
You should get a size mismatch error when loading the state dict.
Expected behavior
The model state loads properly.
Environment
Please copy and paste the output from
python3 -m torch.utils.collect_env
Additional context
I have noticed that if I print the actual value of the parameters in the stack trace, the corresponding tensor in the state_dict is actually empty.
Update
As expected, by commenting out lines 439-482 of compressai.models.video.google.py I can get the loading to work. However, I guess this is not a proper solution.