LiuXiaoxuanPKU / OSD


distil_trainer.py #8

Open ScorpionCG opened 1 month ago

ScorpionCG commented 1 month ago

In this code snippet from `distil_trainer.py`, every call to `generate_one` passes the `past_key_values` that were fixed when the first token was generated (either `student_key_values` or `teacher_key_values`), rather than the updated key-value cache returned after each newly generated token. Why not use the newly generated KV cache?

ScorpionCG commented 1 month ago
    for i in range(max_new_tokens - 1):
        # Pick the student or the teacher for this step, together with
        # that model's *initial* KV cache (never the cache returned by
        # the previous generate_one call).
        sample_model, past_key_values = (student_model, student_key_values) if random.random(
        ) < mix_ratio else (teacher_model, teacher_key_values)
        next_token, _ = generate_one(sample_model, input_ids,
                                     attention_mask, past_key_values)
        # Append the new token and extend the attention mask by one position.
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones(
            bsz, 1, dtype=torch.long, device='cuda')], dim=1)
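For reference, the alternative the question asks about would thread each model's updated cache back into the next step. The sketch below is a toy illustration of that pattern, not the repo's actual code: `make_toy_model` and its list-based "cache" are hypothetical stand-ins for a transformer and its `past_key_values`, and `generate_one` here only mimics the signature shape. One subtlety it makes explicit: because student and teacher are interleaved, each model's cache can lag behind the shared token sequence, so each call must first process the suffix of tokens its cache has not yet seen.

```python
import random

def make_toy_model(offset):
    """Hypothetical stand-in for a model. The 'KV cache' is just the list of
    tokens the model has already processed; a real past_key_values would hold
    attention keys/values instead."""
    def generate_one(tokens, past_key_values):
        # Catch up: absorb any tokens generated (by the other model)
        # since this model's cache was last updated.
        new_cache = past_key_values + tokens[len(past_key_values):]
        # Toy "prediction": last processed token plus a model-specific offset.
        next_token = new_cache[-1] + offset
        return next_token, new_cache
    return generate_one

def mixed_decode(student, teacher, first_token, max_new_tokens, mix_ratio, seed=0):
    rng = random.Random(seed)
    # One cache per model, updated after every call so each model's cache
    # always reflects everything that model has processed so far.
    student_kv, teacher_kv = [], []
    tokens = [first_token]
    for _ in range(max_new_tokens - 1):
        if rng.random() < mix_ratio:
            next_token, student_kv = student(tokens, student_kv)
        else:
            next_token, teacher_kv = teacher(tokens, teacher_kv)
        tokens.append(next_token)
    return tokens
```

With `mix_ratio=1.0` only the student (offset 1) runs and the toy output is `[0, 1, 2, 3]`; with `mix_ratio=0.0` only the teacher runs. The interesting case is an intermediate ratio, where a model picked after several skipped steps must re-process the missed tokens before it can predict, which is the bookkeeping the original code avoids by always passing the full `input_ids` with the initial caches.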