NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech).
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Drastic difference between .nemo and HF checkpoint #11360

Open · rahul-sarvam opened this issue 1 week ago

rahul-sarvam commented 1 week ago

Describe the bug

I have trained a Llama-like model with NeMo using the model config below:

model:
  mcore_gpt: True
  micro_batch_size: 1
  global_batch_size: 512
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  virtual_pipeline_model_parallel_size: null
  context_parallel_size: 1
  encoder_seq_length: 8192
  max_position_embeddings: ${.encoder_seq_length}
  num_layers: 28
  hidden_size: 2048
  ffn_hidden_size: 11008
  num_attention_heads: 16
  init_method_std: 0.02
  use_scaled_init_method: True
  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0
  kv_channels: null
  apply_query_key_layer_scaling: True
  normalization: 'rmsnorm'
  layernorm_epsilon: 1e-6
  do_layer_norm_weight_decay: False
  make_vocab_size_divisible_by: 128
  pre_process: True
  post_process: True
  persist_layer_norm: True
  bias: False
  activation: 'fast-swiglu'
  headscale: False
  transformer_block_type: 'pre_ln'
  openai_gelu: False
  normalize_attention_scores: True
  position_embedding_type: 'rope'
  rotary_percentage: 1.0
  attention_type: 'multihead'
  share_embeddings_and_output_weights: False
  overlap_p2p_comm: False
  batch_p2p_comm: True
  num_query_groups: 8
  rotary_base: 10000.0
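
For reference, my understanding of the HF-side LlamaConfig these hyperparameters should map to after conversion (hf_config and the vocab_size value below are illustrative and for reference only; vocab_size is not part of the snippet above):

from transformers import LlamaConfig

# Rough HF-side equivalent of the NeMo config above (assumption, for reference only)
hf_config = LlamaConfig(
    vocab_size=65536,                 # placeholder, padded via make_vocab_size_divisible_by=128
    hidden_size=2048,                 # hidden_size
    intermediate_size=11008,          # ffn_hidden_size
    num_hidden_layers=28,             # num_layers
    num_attention_heads=16,           # num_attention_heads
    num_key_value_heads=8,            # num_query_groups (GQA)
    max_position_embeddings=8192,     # encoder_seq_length
    rms_norm_eps=1e-6,                # layernorm_epsilon
    rope_theta=10000.0,               # rotary_base
    tie_word_embeddings=False,        # share_embeddings_and_output_weights
    attention_bias=False,             # bias: False
)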

The model works well when I run inference with the .nemo checkpoint (script), but performance drops drastically with the converted HF checkpoint (script). Any ideas why this might be happening? My only hunch is that apply_query_key_layer_scaling=True in NeMo, which may not be applied on the HF side.
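
To make that hunch concrete: as far as I understand it, Megatron-style query-key layer scaling divides the attention scores by the layer number and compensates inside an fp32 softmax, so it should be mathematically a no-op that only changes numerics, whereas HF's Llama attention just scales by 1/sqrt(head_dim). A toy sketch of the two score computations (shapes and kernel details are illustrative, not the actual NeMo/HF code):

import math
import torch

def hf_scores(q, k, head_dim):
    # Standard HF/Llama attention scaling: softmax(QK^T / sqrt(d_k)) in the working dtype
    return torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)

def megatron_qk_layer_scaled_scores(q, k, head_dim, layer_number):
    # Illustrative query-key layer scaling (assumption, not the actual fused kernel):
    # scores are additionally divided by the layer number, then the softmax input is
    # multiplied back by the layer number and computed in fp32, so the result should
    # be mathematically identical but numerically safer in fp16/bf16.
    coeff = float(layer_number)
    scores = q @ k.transpose(-1, -2) / (math.sqrt(head_dim) * coeff)
    return torch.softmax(scores.float() * coeff, dim=-1).to(q.dtype)

# head_dim = hidden_size / num_attention_heads = 2048 / 16 = 128 for this config
q = torch.randn(1, 16, 8, 128, dtype=torch.bfloat16)
k = torch.randn(1, 16, 8, 128, dtype=torch.bfloat16)
a = hf_scores(q, k, 128)
b = megatron_qk_layer_scaled_scores(q, k, 128, layer_number=5)
print(torch.max(torch.abs(a.float() - b.float())).item())  # should be tiny

If that really is the only difference, it seems like it should account for small numerical drift at most, not a qualitative regression.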

Environment details: https://docs.nvidia.com/nemo-framework/user-guide/latest/softwarecomponentversions.html#nemo-framework-24-05

rahul-sarvam commented 4 days ago

I have compared a number of things between the two models, and it looks like there is a large difference between their logits.

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

# (nemo_path, hf_path, tokenizer_path, dummy_trainer, model_config, map_location,
#  device, test_batch_size, test_seq_length, rtol, atol are defined earlier in the script)
metrics = {}

# Load NeMo model
nemo_model = MegatronGPTModel.restore_from(
    nemo_path,
    trainer=dummy_trainer,
    override_config_path=model_config,
    map_location=map_location
)

# Load HuggingFace model
hf_model = AutoModelForCausalLM.from_pretrained(
    hf_path,
    local_files_only=True,
    torch_dtype=torch.bfloat16 # nemo_model.dtype
)

# Load tokenizer
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Move models to device
nemo_model = nemo_model.to(device)
hf_model = hf_model.to(device)

# Set both models to eval mode
nemo_model.eval()
hf_model.eval()

# Create random input ids
input_ids = torch.randint(
    100, 1000,
    (test_batch_size, test_seq_length),
    device=device
)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # NeMo forward pass
    nemo_output = nemo_model(
        tokens=input_ids,
        text_position_ids=torch.arange(test_seq_length, device=device),
        attention_mask=attention_mask,
        labels=None
    )

    # HF forward pass
    hf_output = hf_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True
    ).logits

# Compare logits
logits_match = torch.allclose(
    nemo_output,
    hf_output,
    rtol=rtol,
    atol=atol
)

metrics['logits_max_diff'] = float(
    torch.max(torch.abs(nemo_output - hf_output)).cpu()
)

Output:

Conversion test results:
Logits match: False (max diff: 4.91e+00)
  Parameters match: True (max diff: 0.00e+00)
  Generation match: 0.0
    Sample generation comparison:
      Input text: '<s>[INST] Hello [/INST]\n'
      NeMo output: "<s>[INST] Hello [/INST]\n Hello. It's nice to meet you. Is there something I can help you with or"
      HF output: '<s> [INST] Hello [/INST]\n Hello. ನಿಮ್ಮನ್ನ ಭೇಟಿ ಮಾಡಿ ಸಂತೋಷ ಆಯ್ತು. ನಿಮಗೆ ಏನ'
Number of parameters match: 1.0 (Nemo: 2525087744, HF: 2525087744)
❌ Conversion test failed!

Since the parameters match exactly but the logits diverge, the difference presumably comes from the forward pass itself rather than from the weight mapping, but I am not able to pinpoint what it is. Any pointers would be helpful.
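
In the meantime, one thing I plan to try is localizing where the two forward passes start to diverge by comparing intermediate activations layer by layer. A rough sketch of that check (the NeMo module-name pattern and the activation layout are assumptions for an mcore GPT model and may need adapting):

import re
import torch

# Capture per-layer outputs from the NeMo model with forward hooks, then compare
# them against hf_model's hidden_states to find the first layer that diverges.
nemo_layer_outputs = {}

def make_hook(name):
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        nemo_layer_outputs[name] = out.detach().float()
    return hook

# Assumption: decoder layers show up as "...layers.<i>" in named_modules()
layer_re = re.compile(r"layers\.(\d+)$")
handles = [m.register_forward_hook(make_hook(name))
           for name, m in nemo_model.named_modules() if layer_re.search(name)]

with torch.no_grad():
    nemo_model(
        tokens=input_ids,
        text_position_ids=torch.arange(test_seq_length, device=device),
        attention_mask=attention_mask,
        labels=None,
    )
    hf_hidden = hf_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
    ).hidden_states

for name, nemo_h in sorted(nemo_layer_outputs.items(),
                           key=lambda kv: int(layer_re.search(kv[0]).group(1))):
    i = int(layer_re.search(name).group(1))
    hf_h = hf_hidden[i + 1].float()       # hidden_states[0] is the embedding output
    if nemo_h.shape != hf_h.shape:        # Megatron layers are often [seq, batch, hidden]
        nemo_h = nemo_h.transpose(0, 1)
    print(f"layer {i:02d}  max abs diff: {(nemo_h - hf_h).abs().max().item():.4e}")

for h in handles:
    h.remove()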