microsoft / ProphetNet

A research project for natural language generation, containing the official implementations by the MSRA NLC team.
MIT License

(AR-Diffusion) RuntimeError: Error(s) in loading state_dict for CrossAttention_Diffusion_LM #78

Closed: hwaseem04 closed this issue 8 months ago

hwaseem04 commented 8 months ago

I trained the model from scratch on a custom dataset. After training, I used one of the saved checkpoints to run inference, but I get the error below:

Error executing job with overrides: ['model.name=bert-base-uncased', 'batch_size=128', 'exp.name=xsum', 'load_step=5000', 'data.name=xsum', 'tgt_len=50', 'max_pos_len=512', 'num_samples=50', 'intermediate_size=2048', 'num_attention_heads=8', 'dropout=0.2', 'in_channels=128', 'out_channels=128', 'time_channels=128', 'skip_sample=True', 'gen_timesteps=20', 'schedule_sampler=xy_uniform', 'time_att=True', 'att_strategy=txl', 'load_from_ema=False', 'prediction=True']
Traceback (most recent call last):
  File "./gen_utils/generate.py", line 136, in main
    model.load_state_dict(model_saved_state.model_dict)
  File "/path/to/miniconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CrossAttention_Diffusion_LM:
        Missing key(s) in state_dict: "transformer_blocks.0.r_w_bias", "transformer_blocks.0.r_r_bias", "transformer_blocks.0.time_trans.0.weight", "transformer_blocks.0.time_trans.0.bias", "transformer_blocks.0.time_trans.2.weight", "transformer_blocks.0.time_trans.2.bias", "transformer_blocks.1.r_w_bias", "transformer_blocks.1.r_r_bias", "transformer_blocks.1.time_trans.0.weight", "transformer_blocks.1.time_trans.0.bias", "transformer_blocks.1.time_trans.2.weight", "transformer_blocks.1.time_trans.2.bias", "transformer_blocks.2.r_w_bias", "transformer_blocks.2.r_r_bias", "transformer_blocks.2.time_trans.0.weight", "transformer_blocks.2.time_trans.0.bias", "transformer_blocks.2.time_trans.2.weight", "transformer_blocks.2.time_trans.2.bias", "transformer_blocks.3.r_w_bias", "transformer_blocks.3.r_r_bias", "transformer_blocks.3.time_trans.0.weight", "transformer_blocks.3.time_trans.0.bias", "transformer_blocks.3.time_trans.2.weight", "transformer_blocks.3.time_trans.2.bias", "transformer_blocks.4.r_w_bias", "transformer_blocks.4.r_r_bias", "transformer_blocks.4.time_trans.0.weight", "transformer_blocks.4.time_trans.0.bias", "transformer_blocks.4.time_trans.2.weight", "transformer_blocks.4.time_trans.2.bias", "transformer_blocks.5.r_w_bias", "transformer_blocks.5.r_r_bias", "transformer_blocks.5.time_trans.0.weight", "transformer_blocks.5.time_trans.0.bias", "transformer_blocks.5.time_trans.2.weight", "transformer_blocks.5.time_trans.2.bias". 
        Unexpected key(s) in state_dict: "time_trans.0.weight", "time_trans.0.bias", "time_trans.2.weight", "time_trans.2.bias". 
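A strict `load_state_dict` compares the parameter names the model defines against the names stored in the checkpoint: names only the model has are "missing", names only the checkpoint has are "unexpected". The sketch below reproduces that classification on plain key lists (trimmed to block 0 of the keys in the error above) so the mismatch can be inspected before loading. The helper names and the `shared_layer.weight` key are hypothetical, and the `time_att` heuristic is an assumption read off this error message, not something checked against the repository.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) the way a strict state_dict load classifies keys."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # the model expects them, the checkpoint lacks them
    unexpected = sorted(checkpoint_keys - model_keys)  # the checkpoint has them, the model does not
    return missing, unexpected


def looks_trained_with_time_att(checkpoint_keys):
    """Heuristic (assumption): per-block time_trans / r_w_bias keys show up only
    when the model was built with time_att=True."""
    return any(
        key.startswith("transformer_blocks.")
        and (".time_trans." in key or key.endswith(".r_w_bias"))
        for key in checkpoint_keys
    )


# Key names copied from the error above, trimmed to block 0; "shared_layer.weight" is a
# hypothetical stand-in for the keys both sides agree on.
model_keys = [
    "shared_layer.weight",
    "transformer_blocks.0.r_w_bias",
    "transformer_blocks.0.r_r_bias",
    "transformer_blocks.0.time_trans.0.weight",
]
checkpoint_keys = ["shared_layer.weight", "time_trans.0.weight", "time_trans.0.bias"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
print("checkpoint trained with time_att:", looks_trained_with_time_att(checkpoint_keys))
```

In `generate.py` the two key sets would be `model.state_dict().keys()` and `model_saved_state.model_dict.keys()` (names taken from the traceback). PyTorch's `load_state_dict(..., strict=False)` also reports `missing_keys` and `unexpected_keys`, but it would leave the missing `time_trans` layers untrained, so it is useful for diagnosis rather than as a fix.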

My inference script:

 torchrun --nproc_per_node=2 --nnodes=1 ./gen_utils/generate.py model.name='bert-base-uncased' batch_size=128 exp.name=test load_step=5000 data.name=docedit tgt_len=50 max_pos_len=512 num_samples=50 intermediate_size=2048 num_attention_heads=8 dropout=0.2 in_channels=128 out_channels=128 time_channels=128 skip_sample=True gen_timesteps=20 schedule_sampler='xy_uniform' time_att=True att_strategy='txl' load_from_ema=True prediction=True

Also, what is the difference between load_from_ema=True and load_from_ema=False?

Any pointers on how to debug this? @wutong4012

wutong4012 commented 8 months ago

EMA (an exponential moving average of the model weights) is a commonly used training technique (https://www.fidelity.com/learning-center/trading-investing/technical-analysis/technical-indicator-guide/ema).

If you want to use EMA weights during inference, you also need to enable this parameter during training. The code saves two models: one without EMA and one with EMA.
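As a minimal sketch of what an EMA of weights means: after every optimizer step, a shadow copy is updated as `shadow = decay * shadow + (1 - decay) * current`, and both the raw and the averaged weights are saved. Weights are plain floats here instead of tensors, `decay` is an assumed value, and the checkpoint layout plus the way `select_weights` interprets `load_from_ema` are hypothetical illustrations, not the AR-Diffusion code.

```python
class WeightEMA:
    """Exponential moving average of model weights (sketch; floats stand in for tensors)."""

    def __init__(self, params, decay=0.999):  # decay=0.999 is an assumed default
        self.decay = decay
        self.shadow = dict(params)  # the EMA copy starts from the initial weights

    def update(self, params):
        # Called after each optimizer step: shadow <- decay * shadow + (1 - decay) * current
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value


def build_checkpoint(params, ema):
    # Hypothetical layout: raw and averaged weights saved side by side.
    return {"model": dict(params), "ema": dict(ema.shadow)}


def select_weights(checkpoint, load_from_ema):
    # Hypothetical reading of the flag: pick the averaged or the raw weights.
    return checkpoint["ema"] if load_from_ema else checkpoint["model"]


params = {"w": 0.0}
ema = WeightEMA(params, decay=0.5)  # small decay so the lag is visible after 3 steps
for _ in range(3):
    params["w"] += 1.0  # stand-in for an optimizer update
    ema.update(params)

checkpoint = build_checkpoint(params, ema)
print(select_weights(checkpoint, load_from_ema=False))  # {'w': 3.0}
print(select_weights(checkpoint, load_from_ema=True))   # {'w': 2.125}
```

The averaged weight trails the raw one (2.125 vs 3.0), smoothing out step-to-step noise. Under this sketch, a checkpoint written without an `"ema"` entry would make `load_from_ema=True` fail with a `KeyError`, which is one way to read the advice to enable EMA during training as well.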

wutong4012 commented 8 months ago

According to your error message, I don't think it is caused by EMA. It is caused by the per-block time_trans layers (https://github.com/microsoft/ProphetNet/blob/e2c6657309537b94818f5ddbb2a2c5b5559257bf/AR-diffusion/model_utils/CrossAttention.py#L31), which are enabled by time_att=True: the model built for inference expects them, but your checkpoint was not trained with them. I suggest setting time_att=False to run the code. In addition, if you want to use time_att=True during inference, you should also set this parameter during training.

hwaseem04 commented 8 months ago

> I suggest setting time_att=False to run the code.

Perfect, setting time_att=False removes the error during inference. I realised that I didn't use time_att=True during training.
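Since the fix came down to a flag that differed between training and inference, a small pre-flight check can compare the `key=value` overrides of the two commands. This is a sketch under assumptions: `ARCH_KEYS` is a guessed list of parameter-changing settings (only `time_att` is confirmed by this thread), the helper names are hypothetical, and `TRAIN_SCRIPT.py` plus its flags are a made-up training command.

```python
import shlex

# Assumption: overrides that change the model's parameters. Only time_att is confirmed
# by this thread; the rest are plausible shape-affecting settings listed for illustration.
ARCH_KEYS = (
    "time_att", "att_strategy", "in_channels", "out_channels", "time_channels",
    "intermediate_size", "num_attention_heads", "max_pos_len",
)


def parse_overrides(command):
    """Collect key=value tokens from a command line; flags like --nnodes=1 are skipped."""
    overrides = {}
    for token in shlex.split(command):  # shlex also strips quotes such as 'txl'
        if "=" in token and not token.startswith("-"):
            key, value = token.split("=", 1)
            overrides[key] = value
    return overrides


def arch_mismatches(train_command, infer_command, keys=ARCH_KEYS):
    """Map each key whose value differs to (training value, inference value).

    A key not set on one side is reported as None; whatever default the code uses applies there.
    """
    train, infer = parse_overrides(train_command), parse_overrides(infer_command)
    return {k: (train.get(k), infer.get(k)) for k in keys if train.get(k) != infer.get(k)}


# Hypothetical training command: same architecture flags, but time_att never set.
train_cmd = ("torchrun --nproc_per_node=2 --nnodes=1 TRAIN_SCRIPT.py max_pos_len=512 "
             "intermediate_size=2048 num_attention_heads=8 in_channels=128 out_channels=128 "
             "time_channels=128 att_strategy='txl'")
# Inference command from this thread, abridged to the relevant flags.
infer_cmd = ("torchrun --nproc_per_node=2 --nnodes=1 ./gen_utils/generate.py max_pos_len=512 "
             "intermediate_size=2048 num_attention_heads=8 in_channels=128 out_channels=128 "
             "time_channels=128 time_att=True att_strategy='txl' load_from_ema=True")
print(arch_mismatches(train_cmd, infer_cmd))  # {'time_att': (None, 'True')}
```

Reporting an unset key as `None` rather than guessing the code's default keeps the check honest: it flags exactly the situation in this issue, where `time_att` was never passed during training but was set to `True` at inference.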