microsoft / MASS

MASS: Masked Sequence to Sequence Pre-training for Language Generation
https://arxiv.org/pdf/1905.02450.pdf

How to implement fine-tuned model by myself? #137

Open kaneyxx opened 4 years ago

kaneyxx commented 4 years ago

Hello, and thanks for sharing this awesome project :) I have fine-tuned the supNMT pre-trained model and saved the checkpoint. Now I want to build a model outside the MASS directory and deploy it in a chatbot to use with my friends. I'm not familiar with the fairseq module. What should I do? Any suggestions?

In other words, I have the weights, but I don't know how to build the model and load them. I found the `build_model` function in `xmasked_seq2seq.py`, but I don't know what to do next.
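For reference, fairseq's own loading code follows a build-then-load pattern: rebuild the architecture from the saved args (`task.build_model(args)`), then copy the weights in with `model.load_state_dict(state["model"])`. Below is a minimal sketch of that same pattern in plain PyTorch; the `TinySeq2Seq` class is a hypothetical stand-in, not MASS code, and the in-memory buffer stands in for a checkpoint file on disk.

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for the model that task.build_model(args) would
# construct in fairseq; the real XTransformerModel is built from the task's
# args and dictionaries instead of being hard-coded like this.
class TinySeq2Seq(nn.Module):
    def __init__(self, vocab=16, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab)

    def forward(self, x):
        return self.proj(self.embed(x))


# "Save a checkpoint" the way fairseq does: a dict whose "model" entry
# holds the state_dict. (A BytesIO buffer stands in for checkpoint_best.pt.)
model = TinySeq2Seq()
buf = io.BytesIO()
torch.save({"model": model.state_dict()}, buf)
buf.seek(0)

# Later: rebuild the same architecture, then load the saved weights into it.
rebuilt = TinySeq2Seq()
state = torch.load(buf)
rebuilt.load_state_dict(state["model"])
```

The key constraint is that the rebuilt architecture must expose exactly the parameter names stored in the checkpoint; any mismatch is what produces the "Missing key(s)" errors discussed later in this thread.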

kaneyxx commented 4 years ago

I used the following call to load the pre-trained model and got some errors:

```python
en2zh = TransformerModel.from_pretrained(
    "./",
    checkpoint_file="checkpoint_best.pt",
    task="xmasked_seq2seq",
    arch="xtransformer",
    langs="en,zh",
    source_langs="en",
    target_langs="zh",
    mt_steps="en-zh",
    mass_steps="",
    memt_steps="",
    valid_lang_pairs="",
    no_scale_embedding=True,
    src_dict="./test/processed/dict.en.txt",
    tgt_dict="./test/processed/dict.zh.txt",
)
```

With fairseq 0.7.1 (the same version MASS uses):

```
NotImplementedError                       Traceback (most recent call last)
<ipython-input> in <module>
     12     no_scale_embedding=True,
     13     src_dict="./test/processed/dict.en.txt",
---> 14     tgt_dict="./test/processed/dict.zh.txt")

/opt/conda/lib/python3.7/site-packages/fairseq/models/fairseq_model.py in from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
    199     print(args)
    200
--> 201     return hub_utils.Generator(args, task, models)

/opt/conda/lib/python3.7/site-packages/fairseq/hub_utils.py in __init__(self, args, task, models)
     21     self.task = task
     22     self.models = models
---> 23     self.src_dict = task.source_dictionary

/opt/conda/lib/python3.7/site-packages/fairseq/tasks/fairseq_task.py in source_dictionary(self)
    264     """Return the source :class:`~fairseq.data.Dictionary` (if applicable
    265     for this task)."""
--> 266     raise NotImplementedError

NotImplementedError:
```

And with a local fairseq checkout (`~/fairseq`):

```
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
     12     no_scale_embedding=True,
     13     src_dict="./test/processed/dict.en.txt",
---> 14     tgt_dict="./test/processed/dict.zh.txt")

~/fairseq/fairseq/models/fairseq_model.py in from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
    216     data_name_or_path,
    217     archive_map=cls.hub_models(),
--> 218     **kwargs,

~/fairseq/fairseq/hub_utils.py in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
     71 models, args, task = checkpoint_utils.load_model_ensemble_and_task(
     72     [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
---> 73     arg_overrides=kwargs,

~/fairseq/fairseq/checkpoint_utils.py in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix)
    209     # build model for ensemble
    210     model = task.build_model(args)
--> 211     model.load_state_dict(state["model"], strict=strict, args=args)

~/fairseq/fairseq/models/fairseq_model.py in load_state_dict(self, state_dict, strict, args)
     91     self.upgrade_state_dict(state_dict)
     92     new_state_dict = prune_state_dict(state_dict, args)
---> 93     return super().load_state_dict(new_state_dict, strict)

~/miniconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828     if len(error_msgs) > 0:
    829         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830             self.__class__.__name__, "\n\t".join(error_msgs)))

RuntimeError: Error(s) in loading state_dict for XTransformerModel:
	Missing key(s) in state_dict: "decoders.zh.output_projection.weight".
```

Any help please?
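The second error means the rebuilt `XTransformerModel` declares a parameter (`decoders.zh.output_projection.weight`) that the checkpoint never contained, which typically happens when the fairseq version used for loading builds a slightly different architecture than the one used for training. As the traceback shows, fairseq ultimately calls PyTorch's `Module.load_state_dict`, and that call raises only because `strict` is true. The sketch below, in plain PyTorch with hypothetical module names (not MASS code), shows what `strict=False` reports instead of raising; whether loading non-strictly is actually safe for a given checkpoint is a separate question.

```python
import torch
import torch.nn as nn


# A model with one extra parameter ("output_projection") that the saved
# state_dict will not contain -- mirroring the error above, where the rebuilt
# XTransformerModel expects "decoders.zh.output_projection.weight" but the
# checkpoint lacks it.
class WithProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(16, 8)
        self.output_projection = nn.Linear(8, 16, bias=False)


class WithoutProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(16, 8)


# State saved by the "old" architecture: only embed.weight.
old_state = WithoutProjection().state_dict()

model = WithProjection()
# strict=True (the default) raises RuntimeError: Missing key(s) in state_dict.
# strict=False loads the matching parameters and reports the rest instead.
result = model.load_state_dict(old_state, strict=False)
print(result.missing_keys)  # ['output_projection.weight']
```

Any parameter listed in `missing_keys` keeps its freshly initialized (untrained) values, so a model loaded this way may still need that layer retrained or re-tied before it behaves sensibly.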