arcee-ai / DistillKit

An Open Source Toolkit For LLM Distillation
GNU Affero General Public License v3.0

added some initial logic to load the teacher logits #1

Open shamanez opened 2 months ago

mertege commented 1 week ago

Hi @shamanez, I generated the teacher logits offline and saved them in a dataset column named "logits". After loading the dataset, I confirmed the "logits" key was present before passing it to SFTTrainer. However, inside compute_loss, inputs contains only the "input_ids" and "attention_mask" keys, so the teacher logits never reach the loss. It seems trl's SFTTrainer forwards only "input_ids" and "attention_mask". Have you encountered this kind of problem?
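
For reference, my offline step looks roughly like this (a minimal sketch; the teacher model name, dataset file, and max_length are placeholders, not values from this thread):

```python
# Sketch of precomputing teacher logits offline and storing them in a
# "logits" column. Model name, data file, and max_length are placeholders.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

teacher_name = "teacher-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_name, torch_dtype=torch.bfloat16, device_map="auto"
)
teacher.eval()

@torch.no_grad()
def add_teacher_logits(batch):
    enc = tokenizer(
        batch["text"], truncation=True, max_length=512,
        padding="max_length", return_tensors="pt",
    ).to(teacher.device)
    out = teacher(**enc)
    # Full-vocabulary logits are very large; real pipelines often store
    # only the top-k values per position instead.
    batch["logits"] = out.logits.float().cpu().tolist()
    return batch

ds = load_dataset("json", data_files="train.json", split="train")  # placeholder
ds = ds.map(add_teacher_logits, batched=True, batch_size=4)
ds.save_to_disk("dataset_with_logits")
```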

I appreciate any help you can provide.

shamanez commented 1 week ago

Can you try setting remove_unused_columns=False in the SFTTrainer? By default, the Trainer drops any dataset column that doesn't match the model's forward() signature, which is why your "logits" column disappears before compute_loss.
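
A minimal sketch of how the pieces fit together (the subclass and the KL loss below are illustrative, not DistillKit's actual trainer code):

```python
# Illustrative sketch: keep the precomputed "logits" column in each batch
# with remove_unused_columns=False, then consume it in compute_loss.
import torch.nn.functional as F
from transformers import TrainingArguments
from trl import SFTTrainer

class LogitDistillTrainer(SFTTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop the teacher logits so they are not passed to model.forward().
        teacher_logits = inputs.pop("logits")
        outputs = model(**inputs)
        # KL divergence between student and teacher token distributions
        # (a common distillation loss; the repo's loss may differ).
        loss = F.kl_div(
            F.log_softmax(outputs.logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean",
        )
        return (loss, outputs) if return_outputs else loss

args = TrainingArguments(
    output_dir="out",
    remove_unused_columns=False,  # keep the "logits" column in each batch
)
```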

mertege commented 1 week ago

> Can you try setting remove_unused_columns=False in the SFTTrainer?

Thanks @shamanez, it works.

shing100 commented 6 days ago

Hello, I tried to generate the teacher logits with the code you posted, using the Llama-3.1-70B-Instruct model. However, only CPU RAM usage climbs; the GPUs (8×80 GB) are never used, and the run eventually fails with an error.

I appreciate any help you can provide.

Command used: accelerate launch dataset_with_teacher_logits.py
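
One possible cause (an assumption, since the script itself isn't shown here): accelerate launch starts one process per GPU, so eight processes may each try to materialize the full 70B teacher in CPU RAM before anything reaches a GPU. A minimal sketch of loading the teacher sharded across the GPUs instead (the model id follows Hugging Face hub naming; whether DistillKit's script loads it this way is an assumption):

```python
# Sketch: load a 70B teacher sharded across all visible GPUs so the
# weights land on the 8x80GB cards instead of accumulating in CPU RAM.
# A single `python` process is usually enough when device_map="auto"
# handles the sharding, rather than `accelerate launch`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # roughly 140 GB of weights in bf16
    device_map="auto",           # shard layers across available GPUs
)
model.eval()
```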