neulab / guided_summarization

GSum: A General Framework for Guided Neural Abstractive Summarization
MIT License
113 stars 27 forks

Train bart.large on custom dataset from the beginning #37

Open JaniceXiong opened 2 years ago

JaniceXiong commented 2 years ago

Hi, thank you for releasing the trained model. However, when I try to train bart.large on my custom dataset from the beginning and set model_path to the fairseq bart.large checkpoint, the exception below is raised. It appears to be caused by an architecture mismatch: the guided model contains parameters (e.g. the z_encoder_attn layers) that do not exist in the pretrained checkpoint. I could not find where these parameters are initialized in the z_train.sh script you provided, so I would like to know where I should make a change to initialize them.

RuntimeError: Error(s) in loading state_dict for GuidedBARTModel: Missing key(s) in state_dict: "encoder.layers.12.self_attn.k_proj.weight", "encoder.layers.12.self_attn.k_proj.bias", "encoder.layers.12.self_attn.v_proj.weight", "encoder.layers.12.self_attn.v_proj.bias", "encoder.layers.12.self_attn.q_proj.weight", "encoder.layers.12.self_attn.q_proj.bias", "encoder.layers.12.self_attn.out_proj.weight", "encoder.layers.12.self_attn.out_proj.bias", "encoder.layers.12.self_attn_layer_norm.weight", "encoder.layers.12.self_attn_layer_norm.bias", "encoder.layers.12.fc1.weight", "encoder.layers.12.fc1.bias", "encoder.layers.12.fc2.weight", "encoder.layers.12.fc2.bias", "encoder.layers.12.final_layer_norm.weight", "encoder.layers.12.final_layer_norm.bias", "decoder.layers.0.z_encoder_attn.k_proj.weight", "decoder.layers.0.z_encoder_attn.k_proj.bias", "decoder.layers.0.z_encoder_attn.v_proj.weight", "decoder.layers.0.z_encoder_attn.v_proj.bias", "decoder.layers.0.z_encoder_attn.q_proj.weight", "decoder.layers.0.z_encoder_attn.q_proj.bias", "decoder.layers.0.z_encoder_attn.out_proj.weight", "decoder.layers.0.z_encoder_attn.out_proj.bias", "decoder.layers.0.z_encoder_attn_layer_norm.weight", "decoder.layers.0.z_encoder_attn_layer_norm.bias", "decoder.layers.1.z_encoder_attn.k_proj.weight", "decoder.layers.1.z_encoder_attn.k_proj.bias", "decoder.layers.1.z_encoder_attn.v_proj.weight", "decoder.layers.1.z_encoder_attn.v_proj.bias", "decoder.layers.1.z_encoder_attn.q_proj.weight", "decoder.layers.1.z_encoder_attn.q_proj.bias", "decoder.layers.1.z_encoder_attn.out_proj.weight", "decoder.layers.1.z_encoder_attn.out_proj.bias", "decoder.layers.1.z_encoder_attn_layer_norm.weight", "decoder.layers.1.z_encoder_attn_layer_norm.bias", "decoder.layers.2.z_encoder_attn.k_proj.weight", "decoder.layers.2.z_encoder_attn.k_proj.bias", "decoder.layers.2.z_encoder_attn.v_proj.weight", "decoder.layers.2.z_encoder_attn.v_proj.bias", "decoder.layers.2.z_encoder_attn.q_proj.weight", 
"decoder.layers.2.z_encoder_attn.q_proj.bias", "decoder.layers.2.z_encoder_attn.out_proj.weight", "decoder.layers.2.z_encoder_attn.out_proj.bias", "decoder.layers.2.z_encoder_attn_layer_norm.weight", "decoder.layers.2.z_encoder_attn_layer_norm.bias", "decoder.layers.3.z_encoder_attn.k_proj.weight", "decoder.layers.3.z_encoder_attn.k_proj.bias", "decoder.layers.3.z_encoder_attn.v_proj.weight", "decoder.layers.3.z_encoder_attn.v_proj.bias", "decoder.layers.3.z_encoder_attn.q_proj.weight", "decoder.layers.3.z_encoder_attn.q_proj.bias", "decoder.layers.3.z_encoder_attn.out_proj.weight", "decoder.layers.3.z_encoder_attn.out_proj.bias", "decoder.layers.3.z_encoder_attn_layer_norm.weight", "decoder.layers.3.z_encoder_attn_layer_norm.bias", "decoder.layers.4.z_encoder_attn.k_proj.weight", "decoder.layers.4.z_encoder_attn.k_proj.bias", "decoder.layers.4.z_encoder_attn.v_proj.weight", "decoder.layers.4.z_encoder_attn.v_proj.bias", "decoder.layers.4.z_encoder_attn.q_proj.weight", "decoder.layers.4.z_encoder_attn.q_proj.bias", "decoder.layers.4.z_encoder_attn.out_proj.weight", "decoder.layers.4.z_encoder_attn.out_proj.bias", "decoder.layers.4.z_encoder_attn_layer_norm.weight", "decoder.layers.4.z_encoder_attn_layer_norm.bias", "decoder.layers.5.z_encoder_attn.k_proj.weight", "decoder.layers.5.z_encoder_attn.k_proj.bias", "decoder.layers.5.z_encoder_attn.v_proj.weight", "decoder.layers.5.z_encoder_attn.v_proj.bias", "decoder.layers.5.z_encoder_attn.q_proj.weight", "decoder.layers.5.z_encoder_attn.q_proj.bias", "decoder.layers.5.z_encoder_attn.out_proj.weight", "decoder.layers.5.z_encoder_attn.out_proj.bias", "decoder.layers.5.z_encoder_attn_layer_norm.weight", "decoder.layers.5.z_encoder_attn_layer_norm.bias", "decoder.layers.6.z_encoder_attn.k_proj.weight", "decoder.layers.6.z_encoder_attn.k_proj.bias", "decoder.layers.6.z_encoder_attn.v_proj.weight", "decoder.layers.6.z_encoder_attn.v_proj.bias", "decoder.layers.6.z_encoder_attn.q_proj.weight", 
"decoder.layers.6.z_encoder_attn.q_proj.bias", "decoder.layers.6.z_encoder_attn.out_proj.weight", "decoder.layers.6.z_encoder_attn.out_proj.bias", "decoder.layers.6.z_encoder_attn_layer_norm.weight", "decoder.layers.6.z_encoder_attn_layer_norm.bias", "decoder.layers.7.z_encoder_attn.k_proj.weight", "decoder.layers.7.z_encoder_attn.k_proj.bias", "decoder.layers.7.z_encoder_attn.v_proj.weight", "decoder.layers.7.z_encoder_attn.v_proj.bias", "decoder.layers.7.z_encoder_attn.q_proj.weight", "decoder.layers.7.z_encoder_attn.q_proj.bias", "decoder.layers.7.z_encoder_attn.out_proj.weight", "decoder.layers.7.z_encoder_attn.out_proj.bias", "decoder.layers.7.z_encoder_attn_layer_norm.weight", "decoder.layers.7.z_encoder_attn_layer_norm.bias", "decoder.layers.8.z_encoder_attn.k_proj.weight", "decoder.layers.8.z_encoder_attn.k_proj.bias", "decoder.layers.8.z_encoder_attn.v_proj.weight", "decoder.layers.8.z_encoder_attn.v_proj.bias", "decoder.layers.8.z_encoder_attn.q_proj.weight", "decoder.layers.8.z_encoder_attn.q_proj.bias", "decoder.layers.8.z_encoder_attn.out_proj.weight", "decoder.layers.8.z_encoder_attn.out_proj.bias", "decoder.layers.8.z_encoder_attn_layer_norm.weight", "decoder.layers.8.z_encoder_attn_layer_norm.bias", "decoder.layers.9.z_encoder_attn.k_proj.weight", "decoder.layers.9.z_encoder_attn.k_proj.bias", "decoder.layers.9.z_encoder_attn.v_proj.weight", "decoder.layers.9.z_encoder_attn.v_proj.bias", "decoder.layers.9.z_encoder_attn.q_proj.weight", "decoder.layers.9.z_encoder_attn.q_proj.bias", "decoder.layers.9.z_encoder_attn.out_proj.weight", "decoder.layers.9.z_encoder_attn.out_proj.bias", "decoder.layers.9.z_encoder_attn_layer_norm.weight", "decoder.layers.9.z_encoder_attn_layer_norm.bias", "decoder.layers.10.z_encoder_attn.k_proj.weight", "decoder.layers.10.z_encoder_attn.k_proj.bias", "decoder.layers.10.z_encoder_attn.v_proj.weight", "decoder.layers.10.z_encoder_attn.v_proj.bias", "decoder.layers.10.z_encoder_attn.q_proj.weight", 
"decoder.layers.10.z_encoder_attn.q_proj.bias", "decoder.layers.10.z_encoder_attn.out_proj.weight", "decoder.layers.10.z_encoder_attn.out_proj.bias", "decoder.layers.10.z_encoder_attn_layer_norm.weight", "decoder.layers.10.z_encoder_attn_layer_norm.bias", "decoder.layers.11.z_encoder_attn.k_proj.weight", "decoder.layers.11.z_encoder_attn.k_proj.bias", "decoder.layers.11.z_encoder_attn.v_proj.weight", "decoder.layers.11.z_encoder_attn.v_proj.bias", "decoder.layers.11.z_encoder_attn.q_proj.weight", "decoder.layers.11.z_encoder_attn.q_proj.bias", "decoder.layers.11.z_encoder_attn.out_proj.weight", "decoder.layers.11.z_encoder_attn.out_proj.bias", "decoder.layers.11.z_encoder_attn_layer_norm.weight", "decoder.layers.11.z_encoder_attn_layer_norm.bias".

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "/home/xjw/miniconda3/envs/gsum/bin/fairseq-train", line 33, in sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()) File "/home/xjw/code/guided_summarization/src/fairseq/fairseq_cli/train.py", line 320, in cli_main main(args) File "/home/xjw/code/guided_summarization/src/fairseq/fairseq_cli/train.py", line 81, in main extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) File "/home/xjw/code/guided_summarization/src/fairseq/fairseq/checkpoint_utils.py", line 134, in load_checkpoint reset_meters=args.reset_meters, File "/home/xjw/code/guided_summarization/src/fairseq/fairseq/trainer.py", line 199, in load_checkpoint "please ensure that the architectures match.".format(filename) Exception: Cannot load model parameters from checkpoint ./bart/bart.large/model.pt; please ensure that the architectures match.

Regarding issue #32: even after removing --max-sentences 1 as you suggested, the ZeroDivisionError still occurs when I train with multiple GPUs. With a single GPU the error disappears, but training such a large model on one GPU is too slow.

Thanks for your kind help :) @zdou0830

zdou0830 commented 2 years ago

Hello, just to confirm: did you modify the fairseq/trainer.py file? I tried to handle the mismatch issue here: https://github.com/neulab/guided_summarization/blob/ea4bbe91f189cdb51f7f6a827210f9adc5319b3c/bart/fairseq/trainer.py#L173-L207.
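For context, the general pattern behind this kind of mismatch handling is non-strict state-dict loading: copy every parameter that exists in the pretrained checkpoint and leave the newly added ones (here, the z_encoder_attn layers) at their fresh initialization. The sketch below is a minimal, hypothetical illustration of that idea using toy modules (`Base` stands in for pretrained BART, `Guided` for GuidedBARTModel) — it is not the actual code in trainer.py, which patches this inside fairseq's checkpoint-loading path.

```python
import torch.nn as nn

# Hypothetical stand-ins for illustration only:
# "Base" mimics the pretrained model, "Guided" adds extra
# parameters (analogous to the z_encoder_attn layers).
class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Guided(Base):
    def __init__(self):
        super().__init__()
        # New parameters that are absent from the pretrained checkpoint.
        self.z_proj = nn.Linear(4, 4)

pretrained = Base()
model = Guided()

# strict=False copies all matching keys and reports, rather than
# raising RuntimeError for, the keys that are missing from the
# checkpoint; those stay at their random initialization.
result = model.load_state_dict(pretrained.state_dict(), strict=False)
print(sorted(result.missing_keys))  # ['z_proj.bias', 'z_proj.weight']
```

With strict=True (the default), the same call raises exactly the kind of "Missing key(s) in state_dict" RuntimeError shown above, which is why an unpatched trainer.py fails.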

JaniceXiong commented 2 years ago

Thanks! The code handling the mismatch (lines 185-207) was missing from my trainer.py. After restoring it, training runs fine.