esalesky / visrep

This repository contains an extension of fairseq for pixel / visual representations for machine translation.
https://arxiv.org/abs/2104.08211
MIT License

Interface to load model checkpoint #1

Closed: long21wt closed this issue 2 years ago

long21wt commented 2 years ago

Hi,

Currently, there is no method to load a checkpoint together with its dictionaries. It would be great if you provided a PyTorch Hub interface for the model. Inference would be easier, and the model could be reused in further applications.
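
For example, plain fairseq models can already be loaded through torch.hub along these lines; this is only a sketch of the kind of interface meant, and the model name shown is one of fairseq's upstream WMT19 checkpoints, not a visrep model:

import torch

# Sketch using fairseq's existing torch.hub integration; the model name below
# is an upstream fairseq WMT19 checkpoint, not one of the visrep models.
de2en = torch.hub.load(
    'pytorch/fairseq',
    'transformer.wmt19.de-en.single_model',
    tokenizer='moses',
    bpe='fastbpe',
)
de2en.eval()
print(de2en.translate('Das ist ein Test.'))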

Thanks

esalesky commented 2 years ago

Adding to the PyTorch Hub is a great suggestion and I will look into it this coming week!

For an example of how to use the checkpoint and dictionary files from Zenodo on the command line, see here. You can also queue multiple files for inference using the translate scripts in grid_scripts, similar to fairseq recipes.

long21wt commented 2 years ago

Thanks, I've already tried the given CLI command. It asks for an --image-font-path argument; if I pass a path, e.g. --image-font-path fairseq/data/visual/fonts/NotoMono-Regular.ttf (or another font), I get the following error:

2022-01-09 23:47:13 | INFO | fairseq.tasks.visual_text | dictionary size (de-en/dict.en.txt): 10,072
2022-01-09 23:47:13 | INFO | fairseq.data.visual.image_generator | Loading fonts from fairseq/data/visual/fonts/NotoMono-Regular.ttf
2022-01-09 23:47:13 | INFO | fairseq.data.visual.image_generator | Created 8pt NotoMono-Regular.ttf with image height 19 and est. char width 21
2022-01-09 23:47:13 | INFO | fairseq.data.visual.image_generator | Image window size 30 stride 20
2022-01-09 23:47:13 | INFO | fairseq_cli.interactive | loading model(s) from de-en/checkpoint_best.pt
2022-01-09 23:47:13 | INFO | fairseq.modules.visual | 1Layer embedding (norm: True; bridge relu: False) from 20 * 19 = 380 to 512; conv2d kernel size: (3, 3)
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/me/visrep/fairseq_cli/interactive.py", line 318, in <module>
    cli_main()
  File "/home/me/visrep/fairseq_cli/interactive.py", line 314, in cli_main
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
  File "/home/me/visrep/fairseq/distributed/utils.py", line 364, in call_main
    main(cfg, **kwargs)
  File "/home/me/visrep/fairseq_cli/interactive.py", line 147, in main
    models, _model_args = checkpoint_utils.load_model_ensemble(
  File "/home/me/visrep/fairseq/checkpoint_utils.py", line 297, in load_model_ensemble
    ensemble, args, _task = load_model_ensemble_and_task(
  File "/home/me/visrep/fairseq/checkpoint_utils.py", line 358, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
  File "/home/me/visrep/fairseq/models/fairseq_model.py", line 115, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/home/me/visrep/visrep/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisualTextTransformerModel:
        size mismatch for encoder.cnn_embedder.bridge.weight: copying a param with shape torch.Size([512, 440]) from checkpoint, the shape in current model is torch.Size([512, 380]).

Seems like an issue with the convolutional block.

esalesky commented 2 years ago

Ah, it looks like there is a size mismatch, likely because you're using a different font than the one used for the checkpoints, which changes the rendered image size: the error message says the bridge layer in your current model is size 380, rather than the 440 the checkpoint expects. The parameters used for training are serialized into the model; you only need to pass --image-font-path because your repo path will be different from mine.
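
To see why the font matters: the bridge layer's input size is the flattened slice size reported in the logs (e.g. "from 20 * 19 = 380 to 512"), and the rendered image height depends on the font file and point size. As a rough, hypothetical check of the height a given font produces, using PIL directly rather than the repo's image_generator code:

from PIL import ImageFont

# Hypothetical check, not the repo's exact code path: estimate the rendered
# text height for a given font file at the 8pt default size.
font = ImageFont.truetype('fairseq/data/visual/fonts/NotoSans-Regular.ttf', 8)
ascent, descent = font.getmetrics()
print(ascent + descent)  # approximate image height; different fonts give different values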

Our models use:

NotoSans-Regular.ttf for French, German, and Russian
NotoNaskhArabic-Regular.ttf for Arabic
NotoSansCJKjp-Regular.otf for Chinese, Japanese, and Korean

The font file is selected for you by the scripts in grid_scripts based on the source language you pass (a rough sketch of that mapping follows).
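
The mapping below is only illustrative and reconstructed from the list above; the actual selection lives in the grid_scripts shell scripts:

# Hypothetical mapping, reconstructed from the font list above.
FONT_BY_LANG = {
    'fr': 'NotoSans-Regular.ttf',
    'de': 'NotoSans-Regular.ttf',
    'ru': 'NotoSans-Regular.ttf',
    'ar': 'NotoNaskhArabic-Regular.ttf',
    'zh': 'NotoSansCJKjp-Regular.otf',
    'ja': 'NotoSansCJKjp-Regular.otf',
    'ko': 'NotoSansCJKjp-Regular.otf',
}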

What is the full command (and model) you tried?

long21wt commented 2 years ago

I simply executed the given command (only the paths were modified) and used NotoSans-Regular.ttf for German, but the issue with the CNN is still there.

echo "Ich bin ein robustes Model" | PYTHONPATH=visrep python -m fairseq_cli.interactive ./ --task 'visual_text' --path de-en/checkpoint_best.pt -s de -t en --target-dict de-en/dict.en.txt --beam 5 --image-font-path fairseq/data/visual/fonts/NotoSans-Regular.ttf

This time, the bridge layer is size 420.

2022-01-11 00:20:59 | INFO | fairseq.tasks.visual_text | dictionary size (de-en/dict.en.txt): 10,072
2022-01-11 00:20:59 | INFO | fairseq.data.visual.image_generator | Loading fonts from fairseq/data/visual/fonts/NotoSans-Regular.ttf
2022-01-11 00:20:59 | INFO | fairseq.data.visual.image_generator | Created 8pt NotoSans-Regular.ttf with image height 21 and est. char width 25
2022-01-11 00:20:59 | INFO | fairseq.data.visual.image_generator | Image window size 30 stride 20
2022-01-11 00:20:59 | INFO | fairseq_cli.interactive | loading model(s) from de-en/checkpoint_best.pt
2022-01-11 00:21:00 | INFO | fairseq.modules.visual | 1Layer embedding (norm: True; bridge relu: False) from 20 * 21 = 420 to 512; conv2d kernel size: (3, 3)
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/me/visrep/fairseq_cli/interactive.py", line 318, in <module>
    cli_main()
  File "/home/me/visrep/fairseq_cli/interactive.py", line 314, in cli_main
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
  File "/home/me/visrep/fairseq/distributed/utils.py", line 364, in call_main
    main(cfg, **kwargs)
  File "/home/me/visrep/fairseq_cli/interactive.py", line 147, in main
    models, _model_args = checkpoint_utils.load_model_ensemble(
  File "/home/me/visrep/fairseq/checkpoint_utils.py", line 297, in load_model_ensemble
    ensemble, args, _task = load_model_ensemble_and_task(
  File "/home/me/visrep/fairseq/checkpoint_utils.py", line 358, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
  File "/home/me/visrep/fairseq/models/fairseq_model.py", line 115, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/home/me/visrep/visrep/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisualTextTransformerModel:
        size mismatch for encoder.cnn_embedder.bridge.weight: copying a param with shape torch.Size([512, 440]) from checkpoint, the shape in current model is torch.Size([512, 420]).

esalesky commented 2 years ago

It looks like you're using the TED de-en model, and interactive mode is using the default parameters rather than the serialized ones for some reason; I'll look into that.

If you pass the parameters listed in the README under "Best visual text parameters", there should not be an issue.
The WMT de-en model uses the default parameters and should need only --image-font-path.

For example, for the TED de-en model, this should run as expected:

echo "Ich bin ein robustes Model" | python -m fairseq_cli.interactive ./ --task 'visual_text' --path de-en/checkpoint_best.pt -s de -t en --target-dict de-en/dict.en.txt --beam 5 --image-font-path fairseq/data/visual/fonts/NotoSans-Regular.ttf --image-font-size 10 --image-stride 5 --image-window 20

2022-01-11 08:44:29 | INFO | fairseq.tasks.visual_text | dictionary size (de-en/dict.en.txt): 10,072
2022-01-11 08:44:29 | INFO | fairseq.data.visual.image_generator | Loading fonts from /exp/esalesky/visrep/fairseq-ocr/fairseq/data/visual/fonts/NotoSans-Regular.ttf
2022-01-11 08:44:29 | INFO | fairseq.data.visual.image_generator | Created 10pt NotoSans-Regular.ttf with image height 22 and est. char width 31
2022-01-11 08:44:29 | INFO | fairseq.data.visual.image_generator | Image window size 20 stride 5
2022-01-11 08:44:29 | INFO | fairseq_cli.interactive | loading model(s) from de-en/checkpoint_best.pt
2022-01-11 08:44:30 | INFO | fairseq.modules.visual | 1Layer embedding (norm: True; bridge relu: False) from 20 * 22 = 440 to 512; conv2d kernel size: (3, 3)
2022-01-11 08:44:35 | INFO | fairseq_cli.interactive | NOTE: hypothesis and token scores are output in base 2
2022-01-11 08:44:35 | INFO | fairseq_cli.interactive | Type the input sentence and press return:
W-0 0.087   seconds
H-0 -0.3767355978488922 ▁I ' m ▁a ▁robust ▁model .
D-0 -0.3767355978488922 ▁I ' m ▁a ▁robust ▁model .
P-0 -0.4907 -0.5962 -0.2257 -0.3369 -0.4365 -0.3396 -0.3064 -0.2820
2022-01-11 08:44:35 | INFO | fairseq_cli.interactive | Total time: 6.258 seconds; translation time: 0.087

esalesky commented 2 years ago

I've added an interface so models can be loaded with from_pretrained() in Python, and I've fixed the bug where the default parameters were loaded when image_font_path was passed as an argument, so I'm closing this issue.

Here's an example usage (also in the README):

# Download model, spm, and dict files from Zenodo
wget https://zenodo.org/record/5770933/files/de-en.zip
unzip de-en.zip

# Load the model in python
from fairseq.models.visual import VisualTextTransformerModel
model = VisualTextTransformerModel.from_pretrained(
    checkpoint_file='de-en/checkpoint_best.pt',
    target_dict='de-en/dict.en.txt',
    target_spm='de-en/spm_en.model',
    src='de',
    image_font_path='fairseq/data/visual/fonts/NotoSans-Regular.ttf'
)
model.eval()  # disable dropout (or leave in train mode to finetune)

# Translate
model.translate("Das ist ein Test.")
> 'This is a test.'