NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Environment: Megatron 0.5.0 + TE 1.4; Issue: I pretrained a model without the --use-mcore-models option. Later, I needed context parallelism (CP) to fine-tune on longer sequences, so I had to enable --use-mcore-models, but then I could not load the previously pretrained model and hit an error. #793

Closed: liangshaopeng closed this issue 4 months ago.

liangshaopeng commented 4 months ago

Describe the bug
Environment: Megatron 0.5.0 + TE 1.4. I pretrained a model without the --use-mcore-models option. Later, I needed context parallelism (CP) to fine-tune on longer sequences, so I had to enable --use-mcore-models. I then found that I could not load the previously pretrained model and encountered an error.
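
For orientation, a minimal sketch of the two runs. This is an abridged, hypothetical invocation reconstructed from the argument dumps below (pretrain_gpt.py and the flag subset shown are assumptions; only the paths and flags that appear in the dumps are used):

```bash
# Step 1: pretraining on the legacy model path (--use-mcore-models NOT set);
# all remaining flags as in the first argument dump below.
python pretrain_gpt.py \
    --num-layers 24 --hidden-size 2048 --num-attention-heads 16 \
    --seq-length 2048 --max-position-embeddings 2048 \
    --micro-batch-size 1 --global-batch-size 8 --bf16 \
    --save /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train/

# Step 2: fine-tuning on the mcore model path, required for context
# parallelism (--context-parallel-size would be raised for long sequences).
# Loading the step-1 checkpoint fails here.
python pretrain_gpt.py \
    --use-mcore-models --finetune \
    --num-layers 24 --hidden-size 2048 --num-attention-heads 16 \
    --seq-length 2048 --max-position-embeddings 2048 \
    --micro-batch-size 1 --global-batch-size 8 --bf16 \
    --load /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train/ \
    --save /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train-TE/
```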

To Reproduce
Step 1: I trained a model with the Qwen-1.8B architecture, with randomly initialized parameters. The exact Megatron arguments follow; note that --use-mcore-models was not used.

```
------------------------ arguments ------------------------
  accumulate_allreduce_grads_in_fp32 .............. True
  adam_beta1 ...................................... 0.9
  adam_beta2 ...................................... 0.95
  adam_eps ........................................ 1e-08
  add_bias_linear ................................. False
  add_more_sp_tokens .............................. True
  add_position_embedding .......................... True
  add_qkv_bias .................................... True
  adlr_autoresume ................................. False
  adlr_autoresume_interval ........................ 1000
  apply_layernorm_1p .............................. False
  apply_query_key_layer_scaling ................... False
  apply_residual_connection_post_layernorm ........ False
  apply_rope_fusion ............................... True
  async_tensor_model_parallel_allreduce ........... True
  attention_dropout ............................... 0.0
  attention_softmax_in_fp32 ....................... False
  barrier_with_L1_time ............................ True
  bert_binary_head ................................ True
  bert_embedder_type .............................. megatron
  bert_load ....................................... None
  bf16 ............................................ True
  bias_dropout_fusion ............................. True
  bias_gelu_fusion ................................ False
  bias_swiglu_fusion .............................. True
  biencoder_projection_dim ........................ 0
  biencoder_shared_query_context_model ............ False
  block_data_path ................................. None
  check_for_nan_in_loss_and_grad .................. False
  classes_fraction ................................ 1.0
  clip_grad ....................................... 1.0
  clone_scatter_output_in_embedding ............... True
  consumed_train_samples .......................... 0
  consumed_valid_samples .......................... 0
  context_parallel_size ........................... 1
  data_cache_path ................................. None
  data_parallel_random_init ....................... False
  data_parallel_size .............................. 1
  data_path ....................................... ['/data/oss_bucket_0/eleme_corpus/sft/general/longtext/EGPT-72B-longseq/v9/_chatml_document']
  data_per_class_fraction ......................... 1.0
  data_sharding ................................... True
  dataloader_type ................................. cyclic
  decoder_num_layers .............................. None
  decoder_seq_length .............................. None
  delay_grad_reduce ............................... True
  delay_param_gather .............................. False
  dino_bottleneck_size ............................ 256
  dino_freeze_last_layer .......................... 1
  dino_head_hidden_size ........................... 2048
  dino_local_crops_number ......................... 10
  dino_local_img_size ............................. 96
  dino_norm_last_layer ............................ False
  dino_teacher_temp ............................... 0.07
  dino_warmup_teacher_temp ........................ 0.04
  dino_warmup_teacher_temp_epochs ................. 30
  distribute_saved_activations .................... False
  distributed_backend ............................. nccl
  distributed_timeout_minutes ..................... 10
  embedding_path .................................. None
  empty_unused_memory_level ....................... 0
  enable_one_logger ............................... False
  encoder_num_layers .............................. 24
  encoder_seq_length .............................. 2048
  end_weight_decay ................................ 0.1
  eod_mask_loss ................................... False
  eval_interval ................................... 10000
  eval_iters ...................................... 10
  evidence_data_path .............................. None
  exit_duration_in_mins ........................... None
  exit_interval ................................... None
  exit_on_missing_checkpoint ...................... False
  exit_signal_handler ............................. False
  expert_model_parallel_size ...................... 1
  ffn_hidden_size ................................. 5504
  finetune ........................................ True
  fp16 ............................................ False
  fp16_lm_cross_entropy ........................... False
  fp32_residual_connection ........................ False
  fp8 ............................................. None
  fp8_amax_compute_algo ........................... most_recent
  fp8_amax_history_len ............................ 1
  fp8_interval .................................... 1
  fp8_margin ...................................... 0
  fp8_wgrad ....................................... True
  global_batch_size ............................... 8
  gradient_accumulation_fusion .................... False
  group_query_attention ........................... False
  head_lr_mult .................................... 1.0
  hidden_dropout .................................. 0.0
  hidden_size ..................................... 2048
  hysteresis ...................................... 2
  ict_head_size ................................... None
  ict_load ........................................ None
  img_h ........................................... 224
  img_w ........................................... 224
  indexer_batch_size .............................. 128
  indexer_log_interval ............................ 1000
  inference_batch_times_seqlen_threshold .......... 512
  init_method_std ................................. 0.01
  init_method_xavier_uniform ...................... False
  initial_loss_scale .............................. 4294967296
  iter_per_epoch .................................. 1250
  kv_channels ..................................... 128
  lazy_mpu_init ................................... None
  load ............................................ None
  local_rank ...................................... None
  log_batch_size_to_tensorboard ................... True
  log_interval .................................... 10
  log_learning_rate_to_tensorboard ................ True
  log_loss_scale_to_tensorboard ................... True
  log_memory_to_tensorboard ....................... False
  log_num_zeros_in_grad ........................... False
  log_params_norm ................................. False
  log_progress .................................... False
  log_throughput .................................. False
  log_timers_to_tensorboard ....................... True
  log_validation_ppl_to_tensorboard ............... True
  log_world_size_to_tensorboard ................... False
  loss_scale ...................................... None
  loss_scale_window ............................... 1000
  lr .............................................. 0.0003
  lr_decay_iters .................................. None
  lr_decay_samples ................................ 1000
  lr_decay_style .................................. cosine
  lr_warmup_fraction .............................. None
  lr_warmup_init .................................. 0.0
  lr_warmup_iters ................................. 0
  lr_warmup_samples ............................... 100
  make_vocab_size_divisible_by .................... 128
  manual_gc ....................................... False
  manual_gc_eval .................................. True
  manual_gc_interval .............................. 0
  mask_factor ..................................... 1.0
  mask_prob ....................................... 0.15
  mask_type ....................................... random
  masked_softmax_fusion ........................... True
  max_position_embeddings ......................... 2048
  max_tokens_to_oom ............................... 12000
  merge_file ...................................... qwen_15w.tiktoken
  micro_batch_size ................................ 1
  min_loss_scale .................................. 1.0
  min_lr .......................................... 3e-05
  mock_data ....................................... False
  moe_aux_loss_coeff .............................. 0.0
  moe_grouped_gemm ................................ False
  moe_input_jitter_eps ............................ None
  moe_router_load_balancing_type .................. aux_loss
  moe_router_topk ................................. 2
  moe_token_dropping .............................. False
  moe_z_loss_coeff ................................ None
  nccl_communicator_config_path ................... None
  no_load_optim ................................... True
  no_load_rng ..................................... None
  no_persist_layer_norm ........................... False
  no_save_optim ................................... None
  no_save_rng ..................................... None
  norm_epsilon .................................... 1e-05
  normalization ................................... RMSNorm
  num_attention_heads ............................. 16
  num_channels .................................... 3
  num_classes ..................................... 1000
  num_experts ..................................... None
  num_layers ...................................... 24
  num_layers_per_virtual_pipeline_stage ........... None
  num_query_groups ................................ 1
  num_workers ..................................... 2
  one_logger_entity ............................... hwinf_dcm
  one_logger_project .............................. e2e-tracking
  one_logger_run_name ............................. None
  onnx_safe ....................................... None
  openai_gelu ..................................... False
  optimizer ....................................... adam
  output_bert_embeddings .......................... False
  overlap_grad_reduce ............................. False
  overlap_p2p_comm ................................ False
  overlap_param_gather ............................ False
  override_opt_param_scheduler .................... False
  params_dtype .................................... torch.bfloat16
  patch_dim ....................................... 16
  perform_initialization .......................... True
  pipeline_model_parallel_size .................... 1
  pipeline_model_parallel_split_rank .............. None
  position_embedding_type ......................... rope
  profile ......................................... False
  profile_ranks ................................... [0]
  profile_step_end ................................ 12
  profile_step_start .............................. 10
  query_in_block_prob ............................. 0.1
  rampup_batch_size ............................... None
  rank ............................................ 0
  recompute_granularity ........................... selective
  recompute_method ................................ None
  recompute_num_layers ............................ None
  reset_attention_mask ............................ False
  reset_position_ids .............................. False
  retriever_report_topk_accuracies ................ []
  retriever_score_scaling ......................... False
  retriever_seq_length ............................ 256
  retro_add_retriever ............................. False
  retro_attention_gate ............................ 1
  retro_cyclic_train_iters ........................ None
  retro_encoder_attention_dropout ................. 0.1
  retro_encoder_hidden_dropout .................... 0.1
  retro_encoder_layers ............................ 2
  retro_num_neighbors ............................. 2
  retro_num_retrieved_chunks ...................... 2
  retro_return_doc_ids ............................ False
  retro_verify_neighbor_count ..................... True
  retro_workdir ................................... None
  rotary_interleaved .............................. False
  rotary_percent .................................. 1.0
  rotary_seq_len_interpolation_factor ............. None
  sample_rate ..................................... 1.0
  save ............................................ /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train/
  save_interval ................................... 10000
  scatter_gather_tensors_in_pipeline .............. True
  seed ............................................ 3407
  seq_length ...................................... 2048
  sequence_parallel ............................... False
  sgd_momentum .................................... 0.9
  short_seq_prob .................................. 0.1
  skip_train ...................................... False
  spec ............................................ None
  split ........................................... 98,2,0
  squared_relu .................................... False
  standalone_embedding_stage ...................... False
  start_weight_decay .............................. 0.1
  swiglu .......................................... True
  swin_backbone_type .............................. tiny
  tensor_model_parallel_size ...................... 1
  tensorboard_dir ................................. None
  tensorboard_log_interval ........................ 1
  tensorboard_queue_size .......................... 1
  test_data_path .................................. None
  timing_log_level ................................ 0
  timing_log_option ............................... minmax
  titles_data_path ................................ None
  tokenizer_model ................................. None
  tokenizer_type .................................. QWenTokenizer
  tp_comm_bulk_dgrad .............................. True
  tp_comm_bulk_wgrad .............................. True
  tp_comm_overlap ................................. False
  tp_comm_overlap_cfg ............................. None
  tp_comm_split_ag ................................ True
  tp_comm_split_rs ................................ True
  train_data_path ................................. None
  train_iters ..................................... None
  train_samples ................................... 1000
  transformer_impl ................................ local
  transformer_pipeline_model_parallel_size ........ 1
  untie_embeddings_and_output_weights ............. True
  use_checkpoint_args ............................. False
  use_checkpoint_opt_param_scheduler .............. False
  use_cpu_initialization .......................... None
  use_distributed_optimizer ....................... True
  use_flash_attn .................................. True
  use_mcore_models ................................ False
  use_one_sent_docs ............................... False
  use_ring_exchange_p2p ........................... False
  use_rotary_position_embeddings .................. False
  valid_data_path ................................. None
  variable_seq_lengths ............................ False
  virtual_pipeline_model_parallel_size ............ None
  vision_backbone_type ............................ vit
  vision_pretraining .............................. False
  vision_pretraining_type ......................... classify
  vocab_extra_ids ................................. 0
  vocab_file ...................................... qwen_15w.tiktoken
  vocab_size ...................................... None
  wandb_exp_name ..................................
  wandb_project ...................................
  wandb_save_dir ..................................
  weight_decay .................................... 0.1
  weight_decay_incr_style ......................... constant
  world_size ...................................... 1
-------------------- end of arguments ---------------------
```

Step 2: My plan was to use CP (context parallel) fine-tuning to extend the context length, which requires --use-mcore-models. On resuming training, I found that I could not load the model produced in step 1. ERROR: the failure appears to come from a mismatch in the naming convention of the parameters. However, I can't figure out why enabling --use-mcore-models or not would lead to different parameter names. Isn't this supposed to be seamlessly compatible? (A screenshot of the load error was attached here.)

The arguments for this run are identical to the step-1 dump above except for the following three values (everything else, including transformer_impl local, is unchanged):

```
  load ............................................ /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train/
  save ............................................ /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train-TE/
  use_mcore_models ................................ True
```

Expected behavior
I expected the checkpoint trained without --use-mcore-models to load cleanly after enabling --use-mcore-models. The error seems to be due to a mismatch in the naming convention of the parameters, and I can't figure out why this flag would change the parameter names. Isn't this supposed to be seamlessly compatible?
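
One way to see the mismatch is to inspect what the step-1 checkpoint actually stores versus the names an mcore GPTModel expects. A minimal sketch, assuming the standard Megatron checkpoint layout (the iteration directory below is hypothetical; use the latest iter_* directory under the step-1 save path):

```bash
# Print the top-level structure of the step-1 (legacy) checkpoint.
python -c "
import torch
sd = torch.load('iter_0000125/mp_rank_00/model_optim_rng.pt', map_location='cpu')
print(type(sd['model']), list(sd['model'].keys())[:5])
"
# Legacy checkpoints nest weights under language_model.encoder.layers.N...,
# e.g. ...self_attention.query_key_value.weight, whereas the mcore GPTModel
# expects names such as decoder.layers.N.self_attention.linear_qkv.weight.
# The two schemes do not line up, so loading fails with missing/unexpected keys.
```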

Stack trace/logs: the load failure is shown in the screenshot attached under step 2.

Environment: Megatron-LM 0.5.0 with TransformerEngine (TE) 1.4.

ethanhe42 commented 4 months ago

You need to convert the checkpoint from the legacy format to the mcore format: https://github.com/NVIDIA/Megatron-LM/tree/main/tools/checkpoint
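
For this setup (a GPT model with TP=1/PP=1), the conversion would look roughly like the sketch below. The loader/saver names and flags come from the tools/checkpoint utilities but vary across Megatron-LM versions, so check `python tools/checkpoint/convert.py --help` in your tree first; the output directory name is hypothetical:

```bash
# Convert the legacy (non-mcore) checkpoint from step 1 into mcore format,
# then point --load at the converted directory for the CP fine-tuning run.
python tools/checkpoint/convert.py \
    --model-type GPT \
    --loader megatron \
    --saver mcore \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --load-dir /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train/ \
    --save-dir /data/oss_bucket_0/LLM/Qwen-1_8B-Chat-Pro-Megatron-train-mcore/
```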