Closed chrisdoyleIE closed 4 years ago
Hi, this happens because we remove the unnecessary optimization-history logs from the checkpoint to reduce its file size; only the model weights needed for release are kept. As a result, if you load the checkpoint directly, an error is raised about missing entries. You can refer to this code, which uses model.load_state_dict(states) to load our pretrained weights.
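As a minimal sketch of what loading a weights-only checkpoint looks like (the model here is a stand-in, not ProphetNet's actual code), the idea is to save and restore only the `state_dict`, with `strict=False` tolerating entries that were stripped from the release:

```python
import io
import torch
import torch.nn as nn

# A stand-in model; in ProphetNet this would be NgramTransformerProphetModel.
model = nn.Linear(4, 2)

# Simulate a released checkpoint that contains only the weights,
# with optimizer/history state stripped out.
buffer = io.BytesIO()
torch.save({"model": model.state_dict()}, buffer)
buffer.seek(0)
state = torch.load(buffer)

# Load only the weights; strict=False tolerates keys that are
# missing from (or unexpected in) the checkpoint.
missing, unexpected = model.load_state_dict(state["model"], strict=False)
```

Here both `missing` and `unexpected` come back empty because the stand-in checkpoint matches the model exactly; with a real stripped checkpoint, `strict=False` is what lets loading proceed despite absent bookkeeping entries.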
@qiweizhen thank you for the reply.
In the code for `build_model` that you linked, `model.load_state_dict(states)` is called on `model`, which is instantiated with `model = NgramTransformerProphetModel(encoder, decoder)`. It follows that both `encoder` and `decoder` require instantiation, for which `src_dict` and `tgt_dict` are needed, and those in turn require the `translation_prophetnet` task. Do you have any advice on how to create these dictionaries? Issue continued below.
Alternatively, I tried the code below and got a task `KeyError`:

```python
import torch
from fairseq import tasks

# Unpickle args and the model state_dict
state = torch.load(f'{MODEL_DIR}/{CHECKPOINT_FILE}')

# Attempt to load the task
args = state["args"]
task = tasks.setup_task(args)  # error here: KeyError: 'translation_prophetnet'

# build_model creates the dictionaries
model = task.build_model(args)
model.load_state_dict(state["model"], strict=True, args=args)
```
```
KeyError                                  Traceback (most recent call last)
<ipython-input-11-947a1ef8f52c> in <module>()
      2
      3 args = state["args"]
----> 4 task = tasks.setup_task(args)
      5 model = task.build_model(args)
      6 model.load_state_dict(state["model"], strict=strict, args=args)

/usr/local/lib/python3.6/dist-packages/fairseq/tasks/__init__.py in setup_task(args, **kwargs)
     15
     16 def setup_task(args, **kwargs):
---> 17     return TASK_REGISTRY[args.task].setup_task(args, **kwargs)
     18
     19

KeyError: 'translation_prophetnet'
```
The dictionary we release in this repo is the same as the BERT-uncased dict. The fairseq Dictionary object is used in your defined task, for example in the translation_prophetnet task. If you want to build your own dict, the BPE algorithm should be used; you can refer to the code.
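For reference, a fairseq `dict.txt` is a plain-text file with one `<token> <count>` pair per line (special symbols such as `<pad>` and `</s>` are added by fairseq itself, not listed in the file). A minimal sketch of parsing that format without fairseq, using made-up sample tokens:

```python
# Parse a fairseq-style dict.txt: one "token count" pair per line.
sample = """\
the 12345
##ing 6789
[UNK] 42
"""

vocab = {}
for line in sample.splitlines():
    # Split on the last space so the count is separated from the token.
    token, count = line.rsplit(" ", 1)
    vocab[token] = int(count)
```

This is only to illustrate the file layout; in practice `fairseq.data.Dictionary.load("dict.txt")` does this for you and also prepends the special symbols.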
For your second KeyError problem, did you set --user-dir correctly? The models and tasks under your --user-dir are registered into fairseq with the @register_xxx decorators, for example @register_task.
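The KeyError arises because the registry is populated at import time by the decorator: if fairseq never imports your `--user-dir` module, the task name is simply absent from the lookup table. A minimal sketch of this registry pattern (mimicking how fairseq's `@register_task` and `TASK_REGISTRY` behave, not fairseq's actual code):

```python
# A tiny registry mimicking fairseq's TASK_REGISTRY / @register_task.
TASK_REGISTRY = {}

def register_task(name):
    def wrapper(cls):
        TASK_REGISTRY[name] = cls  # registration happens at import time
        return cls
    return wrapper

@register_task("translation_prophetnet")
class TranslationProphetnetTask:
    @classmethod
    def setup_task(cls, args, **kwargs):
        return cls()

def setup_task(args):
    # Raises KeyError if the decorated class was never imported,
    # which is exactly what a missing --user-dir causes in fairseq.
    return TASK_REGISTRY[args["task"]].setup_task(args)

task = setup_task({"task": "translation_prophetnet"})
```

If the module containing the `@register_task`-decorated class is never imported, the dict lookup fails with `KeyError: 'translation_prophetnet'`; passing `--user-dir` tells fairseq to import that module first.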
I have it working correctly now; I had not registered the task correctly with @register_task.
Thank you!
That code is no longer available at the link you provided. Could you please tell me where I can find it?
Thanks!
> Hi, this happens because we remove the useless optimization history logs from the model to reduce the file size. Only the desired model weights are kept to release. As a result, if you directly load the model, error will be reported that some logs are missed. You can refer to [this code](https://github.com/microsoft/ProphetNet/blob/master/src/prophetnet/ngram_s2s_model.py#L146) with the function model.load_state_dict(states) to load our pretrained weights.
@chrisdoyleIE How did you solve this problem? I encountered the same error. Thanks.
> That code is no longer available with the link you provided. You could please tell me where I can find it?
> Thanks!
> Hi, this happens because we remove the useless optimization history logs from the model to reduce the file size. Only the desired model weights are kept to release. As a result, if you directly load the model, error will be reported that some logs are missed. You can refer to [this code](https://github.com/microsoft/ProphetNet/blob/master/src/prophetnet/ngram_s2s_model.py#L146) with the function model.load_state_dict(states) to load our pretrained weights.
https://github.com/microsoft/ProphetNet/blob/master/ProphetNet_Code/prophetnet/ngram_s2s_model.py
Hi guys,
Thank you for the incredible work.
I tried to load this model from the larger checkpoint in the following manner:
but was presented with a key error:
Versions: fairseq==0.9.0, torch==1.4.0
Any advice on how to proceed would be greatly appreciated; I wish to load ProphetNet into a fairseq model so I can adapt the architecture to a custom task.