goombalab / phi-mamba

Official implementation of Phi-Mamba, a MOHAWK-distilled model (Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models)
https://arxiv.org/abs/2408.10189

Very high loss for stage 1 #3

Status: Open. ostix360 opened this issue 3 days ago

ostix360 commented 3 days ago

Hi, thanks for your work. I want to create a general script that can transform any Transformer model into a hybrid version. I'm struggling to get a reasonable loss value for stage 1, and I don't think the model is learning. I almost copy-pasted your code (the training loop and the discrete Mamba class). Here are some example loss values I got:

Iter 13 Loss: 14136578071405.715
Iter 14 Loss: 77309411328.0
Iter 15 Loss: 42949672960.0
Iter 16 Loss: 565631853.7142857
Iter 17 Loss: 240518168576.0

Here is the training loop:

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

_, student_modules = self.find_modules()
data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
dataloader = DataLoader(self.dataset, batch_size=batch_size, collate_fn=data_collator)

child_module = nn.ModuleList(student_modules).to("cuda", dtype=torch.bfloat16)
self.teacher_model.to("cuda", dtype=torch.bfloat16)
self.teacher_model.eval()
optimizer = AdamW(child_module.parameters(), lr=1e-6)
# Stage 1 skeleton
self.student_model.requires_grad_(True)
for idx, data in enumerate(dataloader):
    input_ids = data["input_ids"].to("cuda")

    _, seq_len = input_ids.size()

    # The teacher is frozen, so its forward pass needs no autograd graph.
    # Attention weights are only returned on the eager attention path
    # (e.g. load the teacher with attn_implementation="eager").
    with torch.no_grad():
        teacher_outputs = self.teacher_model(
            input_ids=input_ids,
            output_hidden_states=True,
            output_attentions=True,
            use_cache=False,
        )
    total_loss = 0.0
    for layer_idx, student_layer in enumerate(child_module):
        optimizer.zero_grad()
        # hidden_states[i] is the input to teacher layer i
        student_input = teacher_outputs.hidden_states[layer_idx]
        # Forward pass
        student_output = student_layer(
            hidden_states=student_input,
            return_mixer_matrix=True,
        )
        # Crop because of our Mamba2 chunking implementation
        transfer_matrix = student_output[1][..., :seq_len, :seq_len]
        attn_matrix = teacher_outputs.attentions[layer_idx]

        assert transfer_matrix.size() == attn_matrix.size()

        # Frobenius distance between mixer and attention matrices,
        # averaged over batch and heads
        loss = torch.linalg.matrix_norm(
            transfer_matrix - attn_matrix, ord="fro"
        ).mean()

        loss.backward()

        nn.utils.clip_grad_norm_(student_layer.parameters(), 1.5)
        optimizer.step()
        total_loss += loss.item()
    # Average the per-layer losses of this batch
    print(f"Iter {idx} Loss: {total_loss / len(child_module)}")
self.check_layer_error()
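For scale: each row of a causal attention matrix is a probability distribution, so for two such rows p and q, ||p - q||_2^2 <= ||p||^2 + ||q||^2 <= 2 (because p·q >= 0). Summed over L rows this gives ||P - Q||_F <= sqrt(2L), and even an all-zeros "student" matrix scores ||P||_F <= sqrt(L). The NumPy sketch below checks these bounds on random causal attention; it is an illustration, not phi-mamba code, and `stage1_loss` simply mirrors the Frobenius loss in the loop. Per-layer losses many orders of magnitude above sqrt(seq_len), like the values logged above, suggest the student's mixer matrix itself is wrong (for example in how it is materialized) rather than that training is merely slow.

```python
import numpy as np

def causal_softmax_attention(q, k):
    """Row-stochastic causal attention matrix softmax(q k^T / sqrt(d)) with a causal mask."""
    L, d = q.shape
    scores = np.where(np.tri(L, dtype=bool), q @ k.T / np.sqrt(d), -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

def stage1_loss(student, teacher):
    """Frobenius norm over the last two axes, averaged over any leading axes."""
    return np.linalg.norm(student - teacher, ord="fro", axis=(-2, -1)).mean()

rng = np.random.default_rng(0)
L, d = 64, 16
P = causal_softmax_attention(rng.standard_normal((L, d)), rng.standard_normal((L, d)))
Q = causal_softmax_attention(rng.standard_normal((L, d)), rng.standard_normal((L, d)))

print(stage1_loss(Q, P), "<=", np.sqrt(2 * L))            # two attention matrices
print(stage1_loss(np.zeros_like(P), P), "<=", np.sqrt(L))  # all-zeros "student"
```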

The only thing I'm not sure about is what you mean by n_qk_heads and n_v_heads. I'm trying to convert a small Qwen model (Qwen2, and now Qwen2.5) with 1.5B parameters, and its config file contains:

"num_attention_heads": 12,
"num_key_value_heads": 2,

If I set n_qk_heads = num_key_value_heads and n_v_heads = num_attention_heads, I get an assertion error:

assert A_log.shape == (batch_size, length, n_heads)

So I decided to set both to num_attention_heads.

Thanks in advance!

AvivBick commented 3 days ago

The n_qk_heads parameter represents the number of query (C) and key (B) projection heads, while n_v_heads refers to the number of value (X) projection heads. This notation is aligned with the State Space Duality described in Mamba2.
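To make these roles concrete, here is a simplified NumPy sketch of materializing Mamba-2 style (SSD) mixer matrices, one per value head: M_h[i, j] = (prod over k = j+1..i of a_h[k]) · <C_g[i], B_g[j]> for i >= j, where value head h reads the shared C/B head g = h // (n_v_heads / n_qk_heads). It assumes one scalar decay per head and time step; the function name, argument layout, and head-grouping order are assumptions for illustration, not phi-mamba's materialize_mixer.

```python
import numpy as np

def ssd_mixer_matrices(B, C, log_a, n_v_heads):
    """Materialize simplified Mamba-2 (SSD) mixer matrices, one per value head.

    B, C:  (n_qk_heads, L, d_state) key / query projections, shared across head groups
    log_a: (n_v_heads, L) per-step log decay (<= 0), one scalar per head and step
    Returns (n_v_heads, L, L) causal mixer matrices.
    """
    n_qk_heads, L, _ = B.shape
    if n_v_heads % n_qk_heads:
        raise ValueError("n_v_heads must be a multiple of n_qk_heads")
    group_size = n_v_heads // n_qk_heads
    causal = np.tri(L, dtype=bool)                 # True on and below the diagonal
    out = np.zeros((n_v_heads, L, L))
    for h in range(n_v_heads):
        g = h // group_size                        # shared C/B head used by value head h
        cum = np.cumsum(log_a[h])
        seg = np.where(causal, cum[:, None] - cum[None, :], -np.inf)
        decay = np.exp(seg)                        # prod_{k=j+1..i} a_h[k]; 0 above diagonal
        out[h] = decay * (C[g] @ B[g].T)           # M[i, j] = decay[i, j] * <C_i, B_j>
    return out

rng = np.random.default_rng(0)
L, d_state, n_qk_heads, n_v_heads = 5, 4, 2, 12
B = rng.standard_normal((n_qk_heads, L, d_state))
C = rng.standard_normal((n_qk_heads, L, d_state))
log_a = -rng.random((n_v_heads, L))
M = ssd_mixer_matrices(B, C, log_a, n_v_heads)
print(M.shape)  # (12, 5, 5): one L x L matrix per value head
```

The number of materialized matrices follows n_v_heads, which is why it has to match the number of attention matrices the teacher returns per layer.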

In this context, sharing query and key heads across value heads (n_qk_heads smaller than n_v_heads) leads to a multi-value architecture rather than a multi-query one. Note that setting n_qk_heads to num_key_value_heads tries to match the student's QK projections to the teacher's KV projections, but the two are not inherently compatible.
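The mismatch can be seen by writing down which projections each head owns. A minimal pure-Python sketch, using the Qwen2-1.5B numbers from the config above (12 attention heads, 2 KV heads) and assuming consecutive heads form a group; `shared_head_index` is a hypothetical helper:

```python
def shared_head_index(n_shared: int, n_heads: int) -> list[int]:
    """For each of n_heads heads, the index of the shared projection head it uses."""
    if n_heads % n_shared:
        raise ValueError("n_heads must be a multiple of n_shared")
    group_size = n_heads // n_shared
    return [h // group_size for h in range(n_heads)]

n_heads, n_shared = 12, 2  # num_attention_heads, num_key_value_heads
per_head = list(range(n_heads))
grouped = shared_head_index(n_shared, n_heads)

# Grouped-query attention (teacher): one query per head, keys/values shared
gqa = {"query": per_head, "key": grouped, "value": grouped}
# Multi-value SSD (student, n_qk_heads=2, n_v_heads=12): C/B shared, X per head
mva = {"query": grouped, "key": grouped, "value": per_head}

for proj in ("query", "key", "value"):
    status = "matches" if gqa[proj] == mva[proj] else "differs"
    print(f"{proj:>5}: {status}")
```

Only the key sharing lines up; the query and value groupings are swapped, so GQA weights have no one-to-one home in a multi-value student.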

I would suggest that:

AvivBick commented 3 days ago

Regarding the high loss, make sure you’re using our latest matrix mixer implementation and that everything is properly aligned. I’d also suggest starting with distilling one Transformer into another as a quick sanity check. This kind of distillation is usually fast and straightforward.
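A toy version of that sanity check, assuming a single causal softmax attention head in NumPy (not phi-mamba code): a student copied from the teacher must give a stage-1 loss of zero, and a perturbed one a positive loss. If a real pipeline does not give roughly zero for a copied layer, the problem lies in alignment (layer indexing, cropping, head order) rather than in optimization.

```python
import numpy as np

def attention_matrix(x, w_q, w_k):
    """Causal softmax attention matrix of one head; x has shape (L, d_model)."""
    q, k = x @ w_q, x @ w_k
    L, d_head = q.shape
    scores = np.where(np.tri(L, dtype=bool), q @ k.T / np.sqrt(d_head), -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

def frob_loss(student_attn, teacher_attn):
    return np.linalg.norm(student_attn - teacher_attn, ord="fro")

rng = np.random.default_rng(0)
L, d_model, d_head = 32, 16, 8
x = rng.standard_normal((L, d_model))  # stands in for teacher hidden states
t_wq = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
t_wk = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
teacher_attn = attention_matrix(x, t_wq, t_wk)

# Student copied from the teacher: the loss should be (numerically) zero.
copy_loss = frob_loss(attention_matrix(x, t_wq.copy(), t_wk.copy()), teacher_attn)
# Perturbed student: the loss should be strictly positive.
s_wq = t_wq + 0.1 * rng.standard_normal(t_wq.shape)
noisy_loss = frob_loss(attention_matrix(x, s_wq, t_wk), teacher_attn)
print(copy_loss, noisy_loss)
```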

tGhattas commented 1 day ago

@ostix360 Also, I would make sure to freeze the rest of the student weights.
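A minimal PyTorch sketch of that freezing step, assuming a hypothetical `ToyBlock` whose sequence mixer lives in a submodule named `mixer` (the real student's attribute names may differ): freeze everything, re-enable gradients only for the mixers, and pass just those parameters to the optimizer.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for one student decoder layer."""
    def __init__(self, d: int = 8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(d)
        self.mixer = nn.Linear(d, d)  # placeholder for the sequence mixer
        self.mlp = nn.Linear(d, d)

student = nn.ModuleList([ToyBlock() for _ in range(2)])

# Freeze everything, then re-enable gradients only for the mixers.
student.requires_grad_(False)
trainable = []
for block in student:
    block.mixer.requires_grad_(True)
    trainable += list(block.mixer.parameters())

optimizer = torch.optim.AdamW(trainable, lr=1e-6)
```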

ostix360 commented 1 day ago

Thanks for the quick answer. Indeed, I forgot to update materialize_mixer, and that is what caused the error. And yes, I did freeze the rest of the weights. OK, so only pure multi-head attention works with your implementation; otherwise I would have to modify the Mamba architecture to make it compatible with (or, let's say, equivalent to) grouped-query attention Transformers?
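This conclusion can be encoded as an explicit guard in a conversion script. A minimal sketch; `student_head_config` is a hypothetical helper, and the keys follow the Hugging Face config names shown earlier (`num_attention_heads`, `num_key_value_heads`):

```python
def student_head_config(config: dict) -> dict:
    """Hypothetical helper: derive n_qk_heads / n_v_heads from a teacher config.

    Only plain multi-head attention (num_key_value_heads == num_attention_heads)
    maps directly onto the student mixer; grouped-query teachers would need
    changes to the student architecture.
    """
    n_heads = config["num_attention_heads"]
    n_kv = config.get("num_key_value_heads", n_heads)  # assume plain MHA when absent
    if n_kv != n_heads:
        raise ValueError(
            f"grouped-query attention ({n_heads} query heads, {n_kv} KV heads) "
            "has no direct equivalent in the student mixer"
        )
    return {"n_qk_heads": n_heads, "n_v_heads": n_heads}

print(student_head_config({"num_attention_heads": 12, "num_key_value_heads": 12}))
```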

AvivBick commented 1 day ago

yep