alibaba / Pai-Megatron-Patch

The official repo of Pai-Megatron-Patch for LLM & VLM large scale training developed by Alibaba Cloud.
Apache License 2.0

Weight conversion problem #280

Closed: Jayce1kk closed this issue 2 months ago

Jayce1kk commented 3 months ago

Problem

After converting the weights, the following error appears when running evaluation:

> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 630167424
 loading release checkpoint from /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
[rank0]: Traceback (most recent call last):
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 207, in <module>
[rank0]:     main()
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 193, in main
[rank0]:     load_checkpoint(model, None, None)
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/training/checkpointing.py", line 807, in load_checkpoint
[rank0]:     model[0].load_state_dict(state_dict['model'], strict=strict)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for GPTModel:
[rank0]:        Missing key(s) in state_dict: "decoder.layers.0.self_attention.linear_proj._extra_state", "decoder.layers.0.self_attention.linear_qkv._extra_state", "decoder.layers.0.mlp.linear_fc1._extra_state", "decoder.layers.0.mlp.linear_fc2._extra_state", "decoder.layers.1.self_attention.linear_proj._extra_state", 

....

"decoder.layers.21.mlp.linear_fc2._extra_state", "decoder.layers.22.self_attention.linear_proj._extra_state", "decoder.layers.22.self_attention.linear_qkv._extra_state", "decoder.layers.22.mlp.linear_fc1._extra_state", "decoder.layers.22.mlp.linear_fc2._extra_state", "decoder.layers.23.self_attention.linear_proj._extra_state", "decoder.layers.23.self_attention.linear_qkv._extra_state", "decoder.layers.23.mlp.linear_fc1._extra_state", "decoder.layers.23.mlp.linear_fc2._extra_state". 
E0703 10:00:55.575000 140162604175808 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 1450033) of binary: /usr/bin/python

Weight conversion command

sh hf2mcore_qwen2_convertor.sh \
0.5B \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1  \
1  \
1  \
1 \
fp32 \
true \
false 

Evaluation command

sh run_evaluate_mcore_qwen.sh \
0.5B \
1 \
256 \
256 \
bf16 \
1 \
1 \
sel \
true \
false \
false \
true \
/raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json \
/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1

Full error output

torchrun --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 19751 evaluate_mcore_qwen.py --valid-data-path /raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json --micro-batch-size 1 --num-layers 24 --hidden-size 896 --num-attention-heads 14 --seq-length 256 --max-position-embeddings 131072 --ffn-hidden-size 4864 --log-interval 1 --eval-interval 100 --eval-iters 10 --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 --no-load-optim --no-load-rng --seed 1234 --num-workers 0 --max-padding-length 256 --extra-vocab-size 293 --patch-tokenizer-type LLamaTokenizer --dataset LLama-Pretrain-Raw --swiglu --normalization RMSNorm --norm-epsilon 1e-6 --use-rotary-position-embeddings --no-rope-fusion --position-embedding-type rope --rotary-base 1000000 --untie-embeddings-and-output-weights --disable-bias-linear --add-qkv-bias --group-query-attention --num-query-groups 2 --eod-mask-loss --bf16 --load /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1 --transformer-impl transformer_engine --recompute-activations --use-distributed-optimizer
INFO:datasets:PyTorch version 2.4.0a0+07cecf4168.nv24.5 available.
in oss file
/raid/LLM_train/Pai-Megatron-Patch/megatron_patch/model/llava/clip_encoder.py:26: UserWarning: The cvcuda environment does not exist. Install cvcuda and use it
  warnings.warn("The cvcuda environment does not exist. Install cvcuda and use it")
using world size: 1, data-parallel size: 1, context-parallel size: 1 tensor-model-parallel size: 1, pipeline-model-parallel size: 1 
setting global batch size to 1
WARNING: Setting args.overlap_p2p_comm to False since non-interleaved schedule does not support overlapping p2p communication
accumulate and all-reduce gradients in fp32 for bfloat16 data type.
using torch.bfloat16 for parameters ...
------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. True
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.999
  adam_eps ........................................ 1e-08
  adaptive_seq_len ................................ False
  add_bias_attn_fc ................................ True
  add_bias_linear ................................. False
  add_bias_linear_fc .............................. True
  add_position_embedding .......................... True
  add_qkv_bias .................................... True
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_layernorm_1p .............................. False
  apply_query_key_layer_scaling ................... False
  apply_residual_connection_post_layernorm ........ False
  apply_rope_fusion ............................... False
  async_save ...................................... None
  async_tensor_model_parallel_allreduce ........... True
  attention_dropout ............................... 0.1
  attention_head_type ............................. None
  attention_softmax_in_fp32 ....................... False
  auto_detect_ckpt_format ......................... False
  barrier_with_L1_time ............................ True
  bert_binary_head ................................ True
  bert_embedder_type .............................. megatron
  bert_load ....................................... None
  bf16 ............................................ True
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ False
  bias_swiglu_fusion .............................. True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  calculate_per_token_loss ........................ False
  check_for_nan_in_loss_and_grad .................. True
  check_weight_hash_across_dp_replicas_interval ... None
  ckpt_assume_constant_structure .................. False
  ckpt_fully_parallel_load ........................ False
  ckpt_fully_parallel_save ........................ False
  ckpt_step ....................................... None
  classes_fraction ................................ 1.0
  clip_grad ....................................... 1.0
  clone_scatter_output_in_embedding ............... True
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  context_parallel_size ........................... 1
  convert_checkpoint_from_megatron_to_transformers  False
  create_attention_mask_in_dataloader ............. True
  cvcuda_image_processing ......................... False
  data_cache_path ................................. None
  data_dir ........................................ None
  data_parallel_random_init ....................... False
  data_parallel_size .............................. 1
  data_path ....................................... None
  data_per_class_fraction ......................... 1.0
  data_sharding ................................... True
  dataloader_type ................................. single
  dataset ......................................... LLama-Pretrain-Raw
  ddp_average_in_collective ....................... False
  ddp_bucket_size ................................. None
  decoder_num_layers .............................. None
  decoder_seq_length .............................. None
  decoupled_lr .................................... None
  decoupled_min_lr ................................ None
  delay_grad_reduce ............................... True
  delay_param_gather .............................. False
  deprecated_use_mcore_models ..................... False
  deterministic_mode .............................. False
  dino_bottleneck_size ............................ 256
  dino_freeze_last_layer .......................... 1
  dino_head_hidden_size ........................... 2048
  dino_local_crops_number ......................... 10
  dino_local_img_size ............................. 96
  dino_norm_last_layer ............................ False
  dino_teacher_temp ............................... 0.07
  dino_warmup_teacher_temp ........................ 0.04
  dino_warmup_teacher_temp_epochs ................. 30
  disable_straggler_on_startup .................... False
  dist_ckpt_format ................................ torch_dist
  distribute_saved_activations .................... False
  distributed_backend ............................. nccl
  distributed_timeout_minutes ..................... 10
  embed_layernorm ................................. False
  embedding_path .................................. None
  empty_unused_memory_level ....................... 0
  enable_one_logger ............................... False
  enable_parallel_output .......................... True
  enable_shared_expert ............................ False
  encoder_num_layers .............................. 24
  encoder_seq_length .............................. 256
  end_weight_decay ................................ 0.01
  eod_mask_loss ................................... True
  epochs .......................................... None
  eval_dev ........................................ False
  eval_fp32 ....................................... False
  eval_interval ................................... 100
  eval_iters ...................................... 10
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  exit_on_missing_checkpoint ...................... False
  exit_signal_handler ............................. False
  expert_interval ................................. 2
  expert_model_parallel_size ...................... 1
  expert_tensor_parallelism ....................... False
  extra_vocab_size ................................ 293
  ffn_hidden_size ................................. 4864
  finetune ........................................ False
  fp16 ............................................ False
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  fp8 ............................................. None
  fp8_amax_compute_algo ........................... most_recent
  fp8_amax_history_len ............................ 1
  fp8_interval .................................... 1
  fp8_margin ...................................... 0
  fp8_wgrad ....................................... True
  freeze_clip_vision_tower ........................ False
  freeze_llm ...................................... False
  generation_length ............................... None
  global_batch_size ............................... 1
  glu_activation .................................. None
  gradient_accumulation_fusion .................... True
  group_query_attention ........................... True
  head_lr_mult .................................... 1.0
  hidden_dropout .................................. 0.1
  hidden_size ..................................... 896
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  image_aspect_ratio .............................. square
  image_folder .................................... 
  image_size ...................................... None
  img_h ........................................... 224
  img_w ........................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  inference_batch_times_seqlen_threshold .......... 512
  init_method_std ................................. 0.02
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 4294967296
  input_len ....................................... 1
  intermediate_size ............................... None
  iter_per_epoch .................................. 1250
  keep_last ....................................... False
  kv_channels ..................................... 64
  kv_lora_rank .................................... None
  lazy_mpu_init ................................... None
  load ............................................ /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... False
  log_interval .................................... 1
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_memory_to_tensorboard ....................... False
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_progress .................................... False
  log_straggler ................................... False
  log_throughput .................................. False
  log_timers_to_tensorboard ....................... False
  log_validation_ppl_to_tensorboard ............... False
  log_world_size_to_tensorboard ................... False
  logging_level ................................... None
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. None
  lr_decay_iters .................................. None
  lr_decay_samples ................................ None
  lr_decay_style .................................. linear
  lr_warmup_fraction .............................. None
  lr_warmup_init .................................. 0.0
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 0
  make_vocab_size_divisible_by .................... 128
  manual_gc ....................................... False
  manual_gc_eval .................................. True
  manual_gc_interval .............................. 0
  mask_factor ..................................... 1.0
  mask_prob ....................................... 0.15
  mask_type ....................................... random
  masked_softmax_fusion ........................... True
  max_padding_length .............................. 256
  max_position_embeddings ......................... 131072
  max_tokens_to_oom ............................... 12000
  merge_file ...................................... None
  micro_batch_size ................................ 1
  min_loss_scale .................................. 1.0
  min_lr .......................................... 0.0
  mm_projector_type ............................... None
  mm_use_im_patch_token ........................... False
  mm_use_im_start_end ............................. False
  mm_vision_select_layer .......................... None
  mmap_bin_files .................................. True
  mock_data ....................................... False
  moe ............................................. False
  moe_aux_loss_coeff .............................. 0.0
  moe_eval_capacity_factor ........................ 1.0
  moe_expert_capacity_factor ...................... None
  moe_expert_parallel_size ........................ None
  moe_extended_tp ................................. False
  moe_ffn_hidden_size ............................. None
  moe_grouped_gemm ................................ False
  moe_input_feature_slicing ....................... False
  moe_input_jitter_eps ............................ None
  moe_layer_freq .................................. 1
  moe_layer_recompute ............................. False
  moe_loss_coeff .................................. 0.01
  moe_min_capacity ................................ 4
  moe_pad_expert_input_to_capacity ................ False
  moe_per_layer_logging ........................... False
  moe_router_load_balancing_type .................. aux_loss
  moe_router_topk ................................. 2
  moe_token_dispatcher_type ....................... allgather
  moe_token_drop_policy ........................... probs
  moe_topk ........................................ 1
  moe_train_capacity_factor ....................... 1.0
  moe_z_loss_coeff ................................ None
  n_head_kv ....................................... None
  nccl_communicator_config_path ................... None
  no_load_optim ................................... True
  no_load_rng ..................................... True
  no_persist_layer_norm ........................... False
  no_save_optim ................................... None
  no_save_rng ..................................... None
  norm_epsilon .................................... 1e-06
  normalization ................................... RMSNorm
  num_attention_heads ............................. 14
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_dataset_builder_threads ..................... 1
  num_experts ..................................... None
  num_fewshot ..................................... None
  num_layers ...................................... 24
  num_layers_per_virtual_pipeline_stage ........... None
  num_query_groups ................................ 2
  num_shared_experts .............................. None
  num_workers ..................................... 0
  one_logger_entity ............................... hwinf_dcm
  one_logger_project .............................. e2e-tracking
  one_logger_run_name ............................. None
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  out_seq_length .................................. 1024
  output_bert_embeddings .......................... False
  overlap_grad_reduce ............................. False
  overlap_p2p_comm ................................ False
  overlap_param_gather ............................ False
  override_opt_param_scheduler .................... False
  params_dtype .................................... torch.bfloat16
  patch_dim ....................................... 16
  patch_size ...................................... None
  patch_tokenizer_type ............................ LLamaTokenizer
  perform_initialization .......................... True
  pipeline_model_parallel_size .................... 1
  pipeline_model_parallel_split_rank .............. None
  position_embedding_type ......................... rope
  position_encoding_2d ............................ False
  pretrained_checkpoint ........................... None
  profile ......................................... False
  profile_ranks ................................... [0]
  profile_step_end ................................ 12
  profile_step_start .............................. 10
  q_lora_rank ..................................... None
  qk_layernorm .................................... False
  qk_nope_head_dim ................................ None
  qk_rope_head_dim ................................ None
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  recompute_granularity ........................... selective
  recompute_method ................................ None
  recompute_num_layers ............................ None
  repetition_penalty .............................. 1.1
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  retro_add_retriever ............................. False
  retro_attention_gate ............................ 1
  retro_cyclic_train_iters ........................ None
  retro_encoder_attention_dropout ................. 0.1
  retro_encoder_hidden_dropout .................... 0.1
  retro_encoder_layers ............................ 2
  retro_num_neighbors ............................. 2
  retro_num_retrieved_chunks ...................... 2
  retro_project_dir ............................... None
  retro_verify_neighbor_count ..................... True
  rotary_base ..................................... 1000000
  rotary_interleaved .............................. False
  rotary_percent .................................. 1.0
  rotary_scale_factor ............................. 1
  rotary_scaling_factor ........................... 1
  rotary_seq_len_interpolation_factor ............. None
  router_type ..................................... topk
  sample_rate ..................................... 1.0
  save ............................................ None
  save_interval ................................... None
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 1234
  seq_length ...................................... 256
  sequence_parallel ............................... False
  sgd_momentum .................................... 0.9
  shared_moe_ffn_hidden_size ...................... None
  short_seq_prob .................................. 0.1
  skip_train ...................................... False
  sliding_window .................................. None
  source_seq_len .................................. None
  spec ............................................ None
  split ........................................... None
  squared_relu .................................... False
  standalone_embedding_stage ...................... False
  start_weight_decay .............................. 0.01
  straggler_ctrlr_port ............................ 65535
  straggler_minmax_count .......................... 1
  swiglu .......................................... True
  swin_backbone_type .............................. tiny
  target_seq_len .................................. None
  task_list ....................................... all
  temperature ..................................... 1.0
  tensor_model_parallel_size ...................... 1
  tensorboard_dir ................................. None
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1000
  test_data_path .................................. None
  test_mode ....................................... False
  text_generate_gt_file ........................... 
  text_generate_input_file ........................ 
  text_generate_output_file ....................... 
  time ............................................ False
  timing_log_level ................................ 0
  timing_log_option ............................... minmax
  titles_data_path ................................ None
  tokenizer_model ................................. None
  tokenizer_type .................................. NullTokenizer
  top_k ........................................... 0
  top_p ........................................... 0.0
  tp_comm_bulk_dgrad .............................. True
  tp_comm_bulk_wgrad .............................. True
  tp_comm_overlap ................................. False
  tp_comm_overlap_ag .............................. True
  tp_comm_overlap_cfg ............................. None
  tp_comm_overlap_rs .............................. True
  tp_comm_overlap_rs_dgrad ........................ False
  tp_comm_split_ag ................................ True
  tp_comm_split_rs ................................ True
  train_data ...................................... None
  train_data_path ................................. None
  train_iters ..................................... None
  train_samples ................................... None
  transformer_impl ................................ transformer_engine
  transformer_pipeline_model_parallel_size ........ 1
  transformer_timers .............................. False
  transformer_type ................................ megatron
  tune_mm_mlp_adapter ............................. False
  untie_embeddings_and_output_weights ............. True
  use_alibi_mask .................................. False
  use_checkpoint_args ............................. False
  use_checkpoint_opt_param_scheduler .............. False
  use_cpu_initialization .......................... None
  use_dist_ckpt ................................... False
  use_distributed_optimizer ....................... True
  use_flash_attn .................................. False
  use_legacy_models ............................... False
  use_llama2_rotary_position_embeddings ........... False
  use_mistral_rotary_position_embeddings .......... False
  use_normhead .................................... False
  use_one_sent_docs ............................... False
  use_ring_exchange_p2p ........................... False
  use_rotary_position_embeddings .................. True
  use_tp_pp_dp_mapping ............................ False
  use_tutel ....................................... False
  v_head_dim ...................................... None
  valid_data ...................................... None
  valid_data_path ................................. ['/raid/LLM_train/Pai-Megatron-Patch/qwen-datasets/alpaca_zh-qwen-valid.json']
  variable_seq_lengths ............................ False
  verbosity ....................................... INFO
  version ......................................... plain
  virtual_pipeline_model_parallel_size ............ None
  vision_backbone_type ............................ vit
  vision_pretraining .............................. False
  vision_pretraining_type ......................... classify
  vision_tower .................................... 
  vocab_extra_ids ................................. 0
  vocab_file ...................................... None
  vocab_size ...................................... -1
  wandb_exp_name .................................. 
  wandb_project ................................... 
  wandb_save_dir .................................. 
  weight_decay .................................... 0.01
  weight_decay_incr_style ......................... constant
  world_size ...................................... 1
  yaml_cfg ........................................ None
  z_loss_weight ................................... 0.0
-------------------- end of arguments ---------------------
setting number of micro-batches to constant 1
> building NullTokenizer tokenizer ...
 > padded vocab (size: 0) with 0 dummy tokens (new size: 0)
> initializing torch distributed ...
> initialized tensor model parallel with size 1
> initialized pipeline model parallel with size 1
> setting random seeds to 1234 ...
> compiling dataset index builder ...
make: Entering directory '/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/core/datasets'
make: Nothing to be done for 'default'.
make: Leaving directory '/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/core/datasets'
>>> done with dataset index builder. Compilation time: 0.500 seconds
WARNING: constraints for invoking optimized fused softmax kernel are not met. We default back to unfused kernel invocations.
> compiling and loading fused kernels ...
>>> done with compiling and loading fused kernels. Compilation time: 0.695 seconds
> building LLamaTokenizer tokenizer ...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Running Encoding (num_proc=16): 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:08<00:00, 14.60 examples/s]
1000it [00:00, 130834.86it/s]
  >> total number of samples: 997
> building LLamaTokenizer tokenizer ...
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
building Qwen2 model ...
 > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 630167424
 loading release checkpoint from /raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/Qwen2-0.5B-hf-mcore-te-tp1-pp1
[rank0]: Traceback (most recent call last):
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 207, in <module>
[rank0]:     main()
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/examples/qwen2/evaluate_mcore_qwen.py", line 193, in main
[rank0]:     load_checkpoint(model, None, None)
[rank0]:   File "/raid/LLM_train/Pai-Megatron-Patch/Megatron-LM-240612/megatron/training/checkpointing.py", line 807, in load_checkpoint
[rank0]:     model[0].load_state_dict(state_dict['model'], strict=strict)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for GPTModel:
[rank0]:        Missing key(s) in state_dict: "decoder.layers.0.self_attention.linear_proj._extra_state", "decoder.layers.0.self_attention.linear_qkv._extra_state", "decoder.layers.0.mlp.linear_fc1._extra_state", "decoder.layers.0.mlp.linear_fc2._extra_state", "decoder.layers.1.self_attention.linear_proj._extra_state", "decoder.layers.1.self_attention.linear_qkv._extra_state", "decoder.layers.1.mlp.linear_fc1._extra_state", "decoder.layers.1.mlp.linear_fc2._extra_state", "decoder.layers.2.self_attention.linear_proj._extra_state", "decoder.layers.2.self_attention.linear_qkv._extra_state", "decoder.layers.2.mlp.linear_fc1._extra_state", "decoder.layers.2.mlp.linear_fc2._extra_state", "decoder.layers.3.self_attention.linear_proj._extra_state", "decoder.layers.3.self_attention.linear_qkv._extra_state", "decoder.layers.3.mlp.linear_fc1._extra_state", "decoder.layers.3.mlp.linear_fc2._extra_state", "decoder.layers.4.self_attention.linear_proj._extra_state", "decoder.layers.4.self_attention.linear_qkv._extra_state", "decoder.layers.4.mlp.linear_fc1._extra_state", "decoder.layers.4.mlp.linear_fc2._extra_state", "decoder.layers.5.self_attention.linear_proj._extra_state", "decoder.layers.5.self_attention.linear_qkv._extra_state", "decoder.layers.5.mlp.linear_fc1._extra_state", "decoder.layers.5.mlp.linear_fc2._extra_state", "decoder.layers.6.self_attention.linear_proj._extra_state", "decoder.layers.6.self_attention.linear_qkv._extra_state", "decoder.layers.6.mlp.linear_fc1._extra_state", "decoder.layers.6.mlp.linear_fc2._extra_state", "decoder.layers.7.self_attention.linear_proj._extra_state", "decoder.layers.7.self_attention.linear_qkv._extra_state", "decoder.layers.7.mlp.linear_fc1._extra_state", "decoder.layers.7.mlp.linear_fc2._extra_state", "decoder.layers.8.self_attention.linear_proj._extra_state", "decoder.layers.8.self_attention.linear_qkv._extra_state", "decoder.layers.8.mlp.linear_fc1._extra_state", "decoder.layers.8.mlp.linear_fc2._extra_state", "decoder.layers.9.self_attention.linear_proj._extra_state", "decoder.layers.9.self_attention.linear_qkv._extra_state", "decoder.layers.9.mlp.linear_fc1._extra_state", "decoder.layers.9.mlp.linear_fc2._extra_state", "decoder.layers.10.self_attention.linear_proj._extra_state", "decoder.layers.10.self_attention.linear_qkv._extra_state", "decoder.layers.10.mlp.linear_fc1._extra_state", "decoder.layers.10.mlp.linear_fc2._extra_state", "decoder.layers.11.self_attention.linear_proj._extra_state", "decoder.layers.11.self_attention.linear_qkv._extra_state", "decoder.layers.11.mlp.linear_fc1._extra_state", "decoder.layers.11.mlp.linear_fc2._extra_state", "decoder.layers.12.self_attention.linear_proj._extra_state", "decoder.layers.12.self_attention.linear_qkv._extra_state", "decoder.layers.12.mlp.linear_fc1._extra_state", "decoder.layers.12.mlp.linear_fc2._extra_state", "decoder.layers.13.self_attention.linear_proj._extra_state", "decoder.layers.13.self_attention.linear_qkv._extra_state", "decoder.layers.13.mlp.linear_fc1._extra_state", "decoder.layers.13.mlp.linear_fc2._extra_state", "decoder.layers.14.self_attention.linear_proj._extra_state", "decoder.layers.14.self_attention.linear_qkv._extra_state", "decoder.layers.14.mlp.linear_fc1._extra_state", "decoder.layers.14.mlp.linear_fc2._extra_state", "decoder.layers.15.self_attention.linear_proj._extra_state", "decoder.layers.15.self_attention.linear_qkv._extra_state", "decoder.layers.15.mlp.linear_fc1._extra_state", "decoder.layers.15.mlp.linear_fc2._extra_state", 
"decoder.layers.16.self_attention.linear_proj._extra_state", "decoder.layers.16.self_attention.linear_qkv._extra_state", "decoder.layers.16.mlp.linear_fc1._extra_state", "decoder.layers.16.mlp.linear_fc2._extra_state", "decoder.layers.17.self_attention.linear_proj._extra_state", "decoder.layers.17.self_attention.linear_qkv._extra_state", "decoder.layers.17.mlp.linear_fc1._extra_state", "decoder.layers.17.mlp.linear_fc2._extra_state", "decoder.layers.18.self_attention.linear_proj._extra_state", "decoder.layers.18.self_attention.linear_qkv._extra_state", "decoder.layers.18.mlp.linear_fc1._extra_state", "decoder.layers.18.mlp.linear_fc2._extra_state", "decoder.layers.19.self_attention.linear_proj._extra_state", "decoder.layers.19.self_attention.linear_qkv._extra_state", "decoder.layers.19.mlp.linear_fc1._extra_state", "decoder.layers.19.mlp.linear_fc2._extra_state", "decoder.layers.20.self_attention.linear_proj._extra_state", "decoder.layers.20.self_attention.linear_qkv._extra_state", "decoder.layers.20.mlp.linear_fc1._extra_state", "decoder.layers.20.mlp.linear_fc2._extra_state", "decoder.layers.21.self_attention.linear_proj._extra_state", "decoder.layers.21.self_attention.linear_qkv._extra_state", "decoder.layers.21.mlp.linear_fc1._extra_state", "decoder.layers.21.mlp.linear_fc2._extra_state", "decoder.layers.22.self_attention.linear_proj._extra_state", "decoder.layers.22.self_attention.linear_qkv._extra_state", "decoder.layers.22.mlp.linear_fc1._extra_state", "decoder.layers.22.mlp.linear_fc2._extra_state", "decoder.layers.23.self_attention.linear_proj._extra_state", "decoder.layers.23.self_attention.linear_qkv._extra_state", "decoder.layers.23.mlp.linear_fc1._extra_state", "decoder.layers.23.mlp.linear_fc2._extra_state". 
E0703 10:00:55.575000 140162604175808 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 1450033) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.4.0a0+07cecf4168.nv24.5', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
evaluate_mcore_qwen.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-07-03_10:00:55
  host      : 6a459af124ab
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1450033)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Jayce1kk commented 3 months ago

In the weight conversion code, changing line 482 of /Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_dense_and_moe_gqa.py from

if full_model[k] is None or "_extra_state" in k:
    full_model.pop(k)

to the code below makes the error go away, but I am not sure whether this affects subsequent training:

if full_model[k] is None:
    full_model.pop(k)
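
For context, the "_extra_state" keys are buffers that Transformer Engine linear layers register to serialize FP8/quantization metadata rather than model weights, so dropping or keeping them does not change the parameters themselves. A quick way to check what actually ended up in the converted checkpoint is the sketch below (a hypothetical inspection script; it assumes the standard Megatron release layout release/mp_rank_00/model_optim_rng.pt under the output directory shown above):

import torch

# Hypothetical inspection script: load the converted checkpoint and count how
# many state_dict entries are "_extra_state" buffers versus real weights/biases.
ckpt_path = ("/raid/LLM_train/Pai-Megatron-Patch/checkpoints/qwen-ckpts/"
             "Qwen2-0.5B-hf-mcore-te-tp1-pp1/release/mp_rank_00/model_optim_rng.pt")

state_dict = torch.load(ckpt_path, map_location="cpu")
model_sd = state_dict["model"]

extra_state_keys = [k for k in model_sd if "_extra_state" in k]
weight_keys = [k for k in model_sd if "_extra_state" not in k]
print(f"{len(weight_keys)} weight keys, {len(extra_state_keys)} _extra_state keys")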
jerryli1981 commented 2 months ago

Hello, when the missing keys are only _extra_state entries this is not actually an error; you just need to set strict=False.
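
For reference, a minimal sketch of that change, assuming the load_checkpoint in this Megatron-LM-240612 snapshot accepts a strict keyword and forwards it to load_state_dict (as the traceback above suggests), is the call site in examples/qwen2/evaluate_mcore_qwen.py:

# In main() of evaluate_mcore_qwen.py: pass strict=False so load_state_dict
# ignores the missing "_extra_state" buffers, which carry Transformer Engine
# metadata rather than trainable weights.
load_checkpoint(model, None, None, strict=False)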

Jayce1kk commented 2 months ago

OK, thank you very much for your reply.