huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Converting fairseq NMT to transformers misses model weight #10298

Closed tagucci closed 3 years ago

tagucci commented 3 years ago

Hi there, question about fairseq NMT model (FSMT) conversion.

I tried to convert my own fairseq-nmt model (transformer_wmt_en_de) based on this conversion script. However, the decoder.embed_out weight goes missing after converting the fairseq model to a transformers FSMT model. This parameter exists when neither --share-all-embeddings nor --share-decoder-input-output-embed is specified, whereas the official fairseq wmt models do not have a decoder.embed_out weight because they were trained with --share-all-embeddings. https://github.com/pytorch/fairseq/issues/2537

Are there any solutions or tips for converting one's own fairseq model?

NielsRogge commented 3 years ago

Pinging @stas00 here

stas00 commented 3 years ago

Thank you for the ping, @NielsRogge

@tagucci, when you file an issue you will find a list of whom to tag for which topic, so please use it to tag the right people. Otherwise it's hard for everybody to follow all the issues.

also when you link to a line of code on github, always hit y first to get the exact sha (it rewrites the url to embed the current git sha). Otherwise your links quickly become invalid, e.g. I have no idea where you were trying to link to with your transformer_wmt_en_de link, as the code was modified today.


OK, could you first clarify where you get "decoder.embed_out weight is missing" - the command line and the backtrace, please. Also a dump of the model (i.e. print(model)).

Now to the guess work.

Does your model miss output_projection weight key?

The context is here: https://github.com/pytorch/fairseq/blob/ab560669cd9baaa4009e1fd01c970f8ffccd1ee0/fairseq/models/transformer.py#L950-L960

fairseq has different versions of their code, and some have keys renamed or added; that's why they have all that logic.

You can see that it's a simple alias - i.e. in fsmt decoder embed and output are always shared.

https://github.com/huggingface/transformers/blob/461e8cacf94d1f76367cc9ba2cfd5b9bd3641c81/src/transformers/models/fsmt/modeling_fsmt.py#L651

So if it's missing you can assign it in the conversion script:

    model_state_dict["model.decoder.output_projection.weight"] = model_state_dict["model.decoder.embed_tokens.weight"]

add this to this line: https://github.com/huggingface/transformers/blob/461e8cacf94d1f76367cc9ba2cfd5b9bd3641c81/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py#L247

but again I could have guessed wrong and will need to see the model dump to tell you more.

You can see the dump of original model I converted from here: https://github.com/stas00/porting/blob/master/transformers/fairseq-wmt19/nbs/config.ipynb

tagucci commented 3 years ago

@NielsRogge Thanks for pinging @stas00! @stas00 Sorry for the trouble with the code links. Following your advice, my model args and model dump are below.

in fsmt decoder embed and output are always shared.

As you said, fsmt does not have the decoder embed and output separately, so my fairseq transformer_wmt_en_de trained without share_decoder_input_output_embed cannot fit FSMT in transformers. In this case, do I need to retrain the fairseq model with share_decoder_input_output_embed or modify FSMTDecoder?

import torch
from pprint import pprint
chkpt = torch.load("model/checkpoint_best.pt")
model = chkpt["model"]
pprint(vars(chkpt["args"]))
print("\n".join(model.keys()))
# args
{'activation_dropout': 0.0,
 'activation_fn': 'relu',
 'adam_betas': '(0.9, 0.98)',
 'adam_eps': 1e-08,
 'adaptive_input': False,
 'adaptive_softmax_cutoff': None,
 'adaptive_softmax_dropout': 0,
 'arch': 'transformer_wmt_en_de',
 'attention_dropout': 0.0,
 'best_checkpoint_metric': 'loss',
 'bpe': None,
 'bucket_cap_mb': 25,
 'clip_norm': 0.0,
 'cpu': False,
 'criterion': 'label_smoothed_cross_entropy',
 'cross_self_attention': False,
 'curriculum': 0,
 'data': './data/data.src_trg',
 'dataset_impl': None,
 'ddp_backend': 'no_c10d',
 'decoder_attention_heads': 8,
 'decoder_embed_dim': 512,
 'decoder_embed_path': None,
 'decoder_ffn_embed_dim': 2048,
 'decoder_input_dim': 512,
 'decoder_layerdrop': 0,
 'decoder_layers': 6,
 'decoder_layers_to_keep': None,
 'decoder_learned_pos': False,
 'decoder_normalize_before': False,
 'decoder_output_dim': 512,
 'device_id': 0,
 'disable_validation': False,
 'distributed_backend': 'nccl',
 'distributed_init_method': 'tcp://localhost:16441',
 'distributed_no_spawn': False,
 'distributed_port': -1,
 'distributed_rank': 0,
 'distributed_world_size': 4,
 'dropout': 0.1,
 'empty_cache_freq': 0,
 'encoder_attention_heads': 8,
 'encoder_embed_dim': 512,
 'encoder_embed_path': None,
 'encoder_ffn_embed_dim': 2048,
 'encoder_layerdrop': 0,
 'encoder_layers': 6,
 'encoder_layers_to_keep': None,
 'encoder_learned_pos': False,
 'encoder_normalize_before': False,
 'fast_stat_sync': False,
 'find_unused_parameters': False,
 'fix_batches_to_gpus': False,
 'fixed_validation_seed': None,
 'fp16': False,
 'fp16_init_scale': 128,
 'fp16_scale_tolerance': 0.0,
 'fp16_scale_window': None,
 'keep_interval_updates': 20,
 'keep_last_epochs': -1,
 'label_smoothing': 0.1,
 'layer_wise_attention': False,
 'layernorm_embedding': False,
 'lazy_load': False,
 'left_pad_source': True,
 'left_pad_target': False,
 'load_alignments': False,
 'log_format': 'json',
 'log_interval': 50,
 'lr': [0.0007],
 'lr_scheduler': 'inverse_sqrt',
 'max_epoch': 100,
 'max_sentences': None,
 'max_sentences_valid': None,
 'max_source_positions': 1024,
 'max_target_positions': 1024,
 'max_tokens': 4096,
 'max_tokens_valid': 4096,
 'max_update': 0,
 'maximize_best_checkpoint_metric': False,
 'memory_efficient_fp16': False,
 'min_loss_scale': 0.0001,
 'min_lr': 1e-09,
 'no_cross_attention': False,
 'no_epoch_checkpoints': True,
 'no_last_checkpoints': False,
 'no_progress_bar': True,
 'no_save': False,
 'no_save_optimizer_state': False,
 'no_scale_embedding': False,
 'no_token_positional_embeddings': False,
 'num_workers': 1,
 'optimizer': 'adam',
 'optimizer_overrides': '{}',
 'raw_text': False,
 'required_batch_size_multiple': 8,
 'reset_dataloader': False,
 'reset_lr_scheduler': False,
 'reset_meters': False,
 'reset_optimizer': False,
 'restore_file': 'checkpoint_last.pt',
 'save_dir': './data/models',
 'save_interval': 1,
 'save_interval_updates': 1000,
 'seed': 1,
 'sentence_avg': False,
 'share_all_embeddings': False,
 'share_decoder_input_output_embed': False,
 'skip_invalid_size_inputs_valid_test': True,
 'source_lang': 'src',
 'target_lang': 'trg',
 'task': 'translation',
 'tensorboard_logdir': '',
 'threshold_loss_scale': None,
 'tokenizer': None,
 'train_subset': 'train',
 'truncate_source': False,
 'update_freq': [16],
 'upsample_primary': 1,
 'use_bmuf': False,
 'user_dir': None,
 'valid_subset': 'valid',
 'validate_interval': 1,
 'warmup_init_lr': 1e-07,
 'warmup_updates': 4000,
 'weight_decay': 0.0}
# model dump
encoder.version
encoder.embed_tokens.weight
encoder.embed_positions._float_tensor
encoder.layers.0.self_attn.k_proj.weight
encoder.layers.0.self_attn.k_proj.bias
encoder.layers.0.self_attn.v_proj.weight
encoder.layers.0.self_attn.v_proj.bias
encoder.layers.0.self_attn.q_proj.weight
encoder.layers.0.self_attn.q_proj.bias
encoder.layers.0.self_attn.out_proj.weight
encoder.layers.0.self_attn.out_proj.bias
encoder.layers.0.self_attn_layer_norm.weight
encoder.layers.0.self_attn_layer_norm.bias
encoder.layers.0.fc1.weight
encoder.layers.0.fc1.bias
encoder.layers.0.fc2.weight
encoder.layers.0.fc2.bias
encoder.layers.0.final_layer_norm.weight
encoder.layers.0.final_layer_norm.bias
encoder.layers.1.self_attn.k_proj.weight
encoder.layers.1.self_attn.k_proj.bias
encoder.layers.1.self_attn.v_proj.weight
encoder.layers.1.self_attn.v_proj.bias
encoder.layers.1.self_attn.q_proj.weight
encoder.layers.1.self_attn.q_proj.bias
encoder.layers.1.self_attn.out_proj.weight
encoder.layers.1.self_attn.out_proj.bias
encoder.layers.1.self_attn_layer_norm.weight
encoder.layers.1.self_attn_layer_norm.bias
encoder.layers.1.fc1.weight
encoder.layers.1.fc1.bias
encoder.layers.1.fc2.weight
encoder.layers.1.fc2.bias
encoder.layers.1.final_layer_norm.weight
encoder.layers.1.final_layer_norm.bias
encoder.layers.2.self_attn.k_proj.weight
encoder.layers.2.self_attn.k_proj.bias
encoder.layers.2.self_attn.v_proj.weight
encoder.layers.2.self_attn.v_proj.bias
encoder.layers.2.self_attn.q_proj.weight
encoder.layers.2.self_attn.q_proj.bias
encoder.layers.2.self_attn.out_proj.weight
encoder.layers.2.self_attn.out_proj.bias
encoder.layers.2.self_attn_layer_norm.weight
encoder.layers.2.self_attn_layer_norm.bias
encoder.layers.2.fc1.weight
encoder.layers.2.fc1.bias
encoder.layers.2.fc2.weight
encoder.layers.2.fc2.bias
encoder.layers.2.final_layer_norm.weight
encoder.layers.2.final_layer_norm.bias
encoder.layers.3.self_attn.k_proj.weight
encoder.layers.3.self_attn.k_proj.bias
encoder.layers.3.self_attn.v_proj.weight
encoder.layers.3.self_attn.v_proj.bias
encoder.layers.3.self_attn.q_proj.weight
encoder.layers.3.self_attn.q_proj.bias
encoder.layers.3.self_attn.out_proj.weight
encoder.layers.3.self_attn.out_proj.bias
encoder.layers.3.self_attn_layer_norm.weight
encoder.layers.3.self_attn_layer_norm.bias
encoder.layers.3.fc1.weight
encoder.layers.3.fc1.bias
encoder.layers.3.fc2.weight
encoder.layers.3.fc2.bias
encoder.layers.3.final_layer_norm.weight
encoder.layers.3.final_layer_norm.bias
encoder.layers.4.self_attn.k_proj.weight
encoder.layers.4.self_attn.k_proj.bias
encoder.layers.4.self_attn.v_proj.weight
encoder.layers.4.self_attn.v_proj.bias
encoder.layers.4.self_attn.q_proj.weight
encoder.layers.4.self_attn.q_proj.bias
encoder.layers.4.self_attn.out_proj.weight
encoder.layers.4.self_attn.out_proj.bias
encoder.layers.4.self_attn_layer_norm.weight
encoder.layers.4.self_attn_layer_norm.bias
encoder.layers.4.fc1.weight
encoder.layers.4.fc1.bias
encoder.layers.4.fc2.weight
encoder.layers.4.fc2.bias
encoder.layers.4.final_layer_norm.weight
encoder.layers.4.final_layer_norm.bias
encoder.layers.5.self_attn.k_proj.weight
encoder.layers.5.self_attn.k_proj.bias
encoder.layers.5.self_attn.v_proj.weight
encoder.layers.5.self_attn.v_proj.bias
encoder.layers.5.self_attn.q_proj.weight
encoder.layers.5.self_attn.q_proj.bias
encoder.layers.5.self_attn.out_proj.weight
encoder.layers.5.self_attn.out_proj.bias
encoder.layers.5.self_attn_layer_norm.weight
encoder.layers.5.self_attn_layer_norm.bias
encoder.layers.5.fc1.weight
encoder.layers.5.fc1.bias
encoder.layers.5.fc2.weight
encoder.layers.5.fc2.bias
encoder.layers.5.final_layer_norm.weight
encoder.layers.5.final_layer_norm.bias
decoder.embed_out
decoder.version
decoder.embed_tokens.weight
decoder.embed_positions._float_tensor
decoder.layers.0.self_attn.k_proj.weight
decoder.layers.0.self_attn.k_proj.bias
decoder.layers.0.self_attn.v_proj.weight
decoder.layers.0.self_attn.v_proj.bias
decoder.layers.0.self_attn.q_proj.weight
decoder.layers.0.self_attn.q_proj.bias
decoder.layers.0.self_attn.out_proj.weight
decoder.layers.0.self_attn.out_proj.bias
decoder.layers.0.self_attn_layer_norm.weight
decoder.layers.0.self_attn_layer_norm.bias
decoder.layers.0.encoder_attn.k_proj.weight
decoder.layers.0.encoder_attn.k_proj.bias
decoder.layers.0.encoder_attn.v_proj.weight
decoder.layers.0.encoder_attn.v_proj.bias
decoder.layers.0.encoder_attn.q_proj.weight
decoder.layers.0.encoder_attn.q_proj.bias
decoder.layers.0.encoder_attn.out_proj.weight
decoder.layers.0.encoder_attn.out_proj.bias
decoder.layers.0.encoder_attn_layer_norm.weight
decoder.layers.0.encoder_attn_layer_norm.bias
decoder.layers.0.fc1.weight
decoder.layers.0.fc1.bias
decoder.layers.0.fc2.weight
decoder.layers.0.fc2.bias
decoder.layers.0.final_layer_norm.weight
decoder.layers.0.final_layer_norm.bias
decoder.layers.1.self_attn.k_proj.weight
decoder.layers.1.self_attn.k_proj.bias
decoder.layers.1.self_attn.v_proj.weight
decoder.layers.1.self_attn.v_proj.bias
decoder.layers.1.self_attn.q_proj.weight
decoder.layers.1.self_attn.q_proj.bias
decoder.layers.1.self_attn.out_proj.weight
decoder.layers.1.self_attn.out_proj.bias
decoder.layers.1.self_attn_layer_norm.weight
decoder.layers.1.self_attn_layer_norm.bias
decoder.layers.1.encoder_attn.k_proj.weight
decoder.layers.1.encoder_attn.k_proj.bias
decoder.layers.1.encoder_attn.v_proj.weight
decoder.layers.1.encoder_attn.v_proj.bias
decoder.layers.1.encoder_attn.q_proj.weight
decoder.layers.1.encoder_attn.q_proj.bias
decoder.layers.1.encoder_attn.out_proj.weight
decoder.layers.1.encoder_attn.out_proj.bias
decoder.layers.1.encoder_attn_layer_norm.weight
decoder.layers.1.encoder_attn_layer_norm.bias
decoder.layers.1.fc1.weight
decoder.layers.1.fc1.bias
decoder.layers.1.fc2.weight
decoder.layers.1.fc2.bias
decoder.layers.1.final_layer_norm.weight
decoder.layers.1.final_layer_norm.bias
decoder.layers.2.self_attn.k_proj.weight
decoder.layers.2.self_attn.k_proj.bias
decoder.layers.2.self_attn.v_proj.weight
decoder.layers.2.self_attn.v_proj.bias
decoder.layers.2.self_attn.q_proj.weight
decoder.layers.2.self_attn.q_proj.bias
decoder.layers.2.self_attn.out_proj.weight
decoder.layers.2.self_attn.out_proj.bias
decoder.layers.2.self_attn_layer_norm.weight
decoder.layers.2.self_attn_layer_norm.bias
decoder.layers.2.encoder_attn.k_proj.weight
decoder.layers.2.encoder_attn.k_proj.bias
decoder.layers.2.encoder_attn.v_proj.weight
decoder.layers.2.encoder_attn.v_proj.bias
decoder.layers.2.encoder_attn.q_proj.weight
decoder.layers.2.encoder_attn.q_proj.bias
decoder.layers.2.encoder_attn.out_proj.weight
decoder.layers.2.encoder_attn.out_proj.bias
decoder.layers.2.encoder_attn_layer_norm.weight
decoder.layers.2.encoder_attn_layer_norm.bias
decoder.layers.2.fc1.weight
decoder.layers.2.fc1.bias
decoder.layers.2.fc2.weight
decoder.layers.2.fc2.bias
decoder.layers.2.final_layer_norm.weight
decoder.layers.2.final_layer_norm.bias
decoder.layers.3.self_attn.k_proj.weight
decoder.layers.3.self_attn.k_proj.bias
decoder.layers.3.self_attn.v_proj.weight
decoder.layers.3.self_attn.v_proj.bias
decoder.layers.3.self_attn.q_proj.weight
decoder.layers.3.self_attn.q_proj.bias
decoder.layers.3.self_attn.out_proj.weight
decoder.layers.3.self_attn.out_proj.bias
decoder.layers.3.self_attn_layer_norm.weight
decoder.layers.3.self_attn_layer_norm.bias
decoder.layers.3.encoder_attn.k_proj.weight
decoder.layers.3.encoder_attn.k_proj.bias
decoder.layers.3.encoder_attn.v_proj.weight
decoder.layers.3.encoder_attn.v_proj.bias
decoder.layers.3.encoder_attn.q_proj.weight
decoder.layers.3.encoder_attn.q_proj.bias
decoder.layers.3.encoder_attn.out_proj.weight
decoder.layers.3.encoder_attn.out_proj.bias
decoder.layers.3.encoder_attn_layer_norm.weight
decoder.layers.3.encoder_attn_layer_norm.bias
decoder.layers.3.fc1.weight
decoder.layers.3.fc1.bias
decoder.layers.3.fc2.weight
decoder.layers.3.fc2.bias
decoder.layers.3.final_layer_norm.weight
decoder.layers.3.final_layer_norm.bias
decoder.layers.4.self_attn.k_proj.weight
decoder.layers.4.self_attn.k_proj.bias
decoder.layers.4.self_attn.v_proj.weight
decoder.layers.4.self_attn.v_proj.bias
decoder.layers.4.self_attn.q_proj.weight
decoder.layers.4.self_attn.q_proj.bias
decoder.layers.4.self_attn.out_proj.weight
decoder.layers.4.self_attn.out_proj.bias
decoder.layers.4.self_attn_layer_norm.weight
decoder.layers.4.self_attn_layer_norm.bias
decoder.layers.4.encoder_attn.k_proj.weight
decoder.layers.4.encoder_attn.k_proj.bias
decoder.layers.4.encoder_attn.v_proj.weight
decoder.layers.4.encoder_attn.v_proj.bias
decoder.layers.4.encoder_attn.q_proj.weight
decoder.layers.4.encoder_attn.q_proj.bias
decoder.layers.4.encoder_attn.out_proj.weight
decoder.layers.4.encoder_attn.out_proj.bias
decoder.layers.4.encoder_attn_layer_norm.weight
decoder.layers.4.encoder_attn_layer_norm.bias
decoder.layers.4.fc1.weight
decoder.layers.4.fc1.bias
decoder.layers.4.fc2.weight
decoder.layers.4.fc2.bias
decoder.layers.4.final_layer_norm.weight
decoder.layers.4.final_layer_norm.bias
decoder.layers.5.self_attn.k_proj.weight
decoder.layers.5.self_attn.k_proj.bias
decoder.layers.5.self_attn.v_proj.weight
decoder.layers.5.self_attn.v_proj.bias
decoder.layers.5.self_attn.q_proj.weight
decoder.layers.5.self_attn.q_proj.bias
decoder.layers.5.self_attn.out_proj.weight
decoder.layers.5.self_attn.out_proj.bias
decoder.layers.5.self_attn_layer_norm.weight
decoder.layers.5.self_attn_layer_norm.bias
decoder.layers.5.encoder_attn.k_proj.weight
decoder.layers.5.encoder_attn.k_proj.bias
decoder.layers.5.encoder_attn.v_proj.weight
decoder.layers.5.encoder_attn.v_proj.bias
decoder.layers.5.encoder_attn.q_proj.weight
decoder.layers.5.encoder_attn.q_proj.bias
decoder.layers.5.encoder_attn.out_proj.weight
decoder.layers.5.encoder_attn.out_proj.bias
decoder.layers.5.encoder_attn_layer_norm.weight
decoder.layers.5.encoder_attn_layer_norm.bias
decoder.layers.5.fc1.weight
decoder.layers.5.fc1.bias
decoder.layers.5.fc2.weight
decoder.layers.5.fc2.bias
decoder.layers.5.final_layer_norm.weight
decoder.layers.5.final_layer_norm.bias
stas00 commented 3 years ago

Thank you for the model dump - my guess was correct: it's missing output_projection, and I gave you the solution at the end of my previous comment.

I still don't know what error you get, when, or the backtrace, but perhaps my guessed solution is all you need.

But no, you don't need to re-train.

If it works, could you adapt the script to check whether the checkpoint being loaded is missing this key and, if so, copy it as I suggested?
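
For reference, a minimal sketch of that check (key names taken from the assignment above; exactly where it goes inside convert_fsmt_original_pytorch_checkpoint_to_pytorch.py is an assumption):

# Hedged sketch: alias the decoder output projection to the decoder input embedding
# only when the checkpoint was trained without shared output embeddings and the key is absent.
key = "model.decoder.output_projection.weight"
if key not in model_state_dict:
    model_state_dict[key] = model_state_dict["model.decoder.embed_tokens.weight"]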

tagucci commented 3 years ago

@stas00 Running convert_fsmt_original_pytorch_checkpoint_to_pytorch.py succeeds, but there is still something wrong. When comparing the fairseq model provided by torch.hub with the converted HF model, the translation results match.

from transformers import FSMTForConditionalGeneration, FSMTTokenizer, TranslationPipeline
import torch

input_text = "Machine learning is great!"
# fairseq
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
                       tokenizer='moses', bpe='fastbpe')
fairseq_res = en2de.translate(input_text)
# transformers
fsmt_path = "./fairseq2hf/data/wmt19-en-de/"
tokenizer = FSMTTokenizer.from_pretrained(fsmt_path)
model = FSMTForConditionalGeneration.from_pretrained(fsmt_path)
nlp = TranslationPipeline(model=model, tokenizer=tokenizer)
fsmt_res = nlp(input_text)[0]["translation_text"]

print("fairseq: {}".format(fairseq_res))
print("transformer: {}".format(fsmt_res))
print("match: {}".format(fairseq_res == fsmt_res))
"""
fairseq: Maschinelles Lernen ist großartig!
transformer: Maschinelles Lernen ist großartig!
match: True
"""

However, my own fairseq model and the converted HF model produce different results with the same parameters (beam_size=5). Do you have any idea how to debug why the translation results differ?

fairseq result

# encoded source tokens used by fairseq-interactive
tensor([[5269, 2069,    5, 1154,    9,    4, 1823, 3382,    5, 3128,  116,  167,
         1582,    7, 2192,  914,   63,    6, 1823, 2807,  124, 1219, 1106,    8,
           53, 2175, 2007,  483,    4,  660,  708, 5229,   33,   44,    4, 6049,
         1430,    5, 1806, 2050, 2282, 1908,    4,  334, 3229, 4808, 6102,    5,
         5031,   11,    5,  291, 4214, 6485,   10, 5784, 1908,   23, 1765, 4916,
            6,    2]])

# hypo tokens from fairseq-interactive
tensor([ 924, 4938,    6, 3056,   59,  503, 1497,    4, 5835,  847,    6,  592,
           2], dtype=torch.int32)

transformers result

encoded_token = torch.tensor([[5269, 2069, 5, 1154, 9, 4, 1823, 3382, 5, 3128, 116, 167, 1582, 7, 2192, 914, 63, 6, 1823, 2807, 124, 1219, 1106, 8, 53, 2175, 2007, 483, 4, 660, 708, 5229, 33, 44, 4, 6049, 1430, 5, 1806, 2050, 2282, 1908, 4, 334, 3229, 4808, 6102, 5, 5031, 11, 5, 291, 4214, 6485, 10, 5784, 1908, 23, 1765, 4916, 6, 2]])

fsmt = FSMTForConditionalGeneration.from_pretrained("./fairseq2HF/")
hypo = fsmt.generate(encoded_token, num_beams=5)
print(hypo)
# tensor([[ 2, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,  2]])
stas00 commented 3 years ago

I'm a bit lost - we were discussing a missing state dict key, now we are discussing invalid translation.

Did my suggestion help to resolve the problem of the missing key and now you're presenting the next issue?

Wrt your transformers result with your model, do you get any better behavior if you encode the tokens via transformers and then feed them to generate? Perhaps the dict has somehow changed? Though a repeated 21 is suspiciously bad.
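
For illustration, a sketch of what that could look like (the checkpoint path reuses the one from your snippet; the input sentence is just the earlier example):

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

# Hedged sketch: encode the source text with the converted FSMT tokenizer instead of
# reusing ids produced by fairseq's Dictionary, then decode the generated hypothesis.
tokenizer = FSMTTokenizer.from_pretrained("./fairseq2HF/")
model = FSMTForConditionalGeneration.from_pretrained("./fairseq2HF/")

inputs = tokenizer("Machine learning is great!", return_tensors="pt")
hypo = model.generate(inputs["input_ids"], num_beams=5)
print(tokenizer.decode(hypo[0], skip_special_tokens=True))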

tagucci commented 3 years ago

@stas00

Did my suggestion help to resolve the problem of the missing key and now you're presenting the next issue?

Yes, thanks for the helpful comments. Sorry, I should have posted it as a separate issue.

do you get any better behavior if you encode the tokens via transformers and then feed it to generate?

I do not use the transformers tokenizer because my fairseq model has a different vocab size, and it's impossible to encode/decode with a single tokenizer model. Token-to-id conversion is done with fairseq's Dictionary, roughly as sketched below. I'll post another issue if necessary after scrutinizing my code.
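
(A sketch of that step for context; the dict path is an assumption based on the training args above, and the input line is assumed to be already tokenized/BPE'd.)

from fairseq.data import Dictionary

# Hedged sketch: load the source-side dict used at training time and map a
# preprocessed line to ids, without adding new symbols.
src_dict = Dictionary.load("./data/data.src_trg/dict.src.txt")
ids = src_dict.encode_line("machine learning is great !", add_if_not_exist=False)
print(ids)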

Thanks for the big help!

stas00 commented 3 years ago

Thank you for clarifying that your original issue has been resolved. Please feel free to close this issue when you feel it's working for you.

Based on your comments, I'm concerned about 2 things:

  1. your different dictionaries - a model has to come with the exact dict it was trained on, after conversion too. So it sounds like something isn't right there. If you're not sure what's happening, perhaps try to clarify how it came to be that your fairseq model has a different vocab size (see the sketch after this list).
  2. perhaps that output_projection layer is getting in the way of your model if it was trained without it. You could try to hunt down the few lines where it's used in the code, bypass it, and test whether your translation works then. If you're comfortable editing the source code, that is.
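
As a quick way to check 1., a sketch that compares the decoder vocab size of the original checkpoint with the converted model (paths are the ones from your earlier snippets; the config attribute is assumed from FSMTConfig):

import torch
from transformers import FSMTForConditionalGeneration

# Hedged sketch: the fairseq decoder input embedding row count should match the
# converted model's target vocab size (key name taken from your model dump).
chkpt = torch.load("model/checkpoint_best.pt")
fairseq_tgt_vocab = chkpt["model"]["decoder.embed_tokens.weight"].shape[0]

model = FSMTForConditionalGeneration.from_pretrained("./fairseq2HF/")
print(fairseq_tgt_vocab, model.config.tgt_vocab_size)  # these should match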