kevinyaobytedance / llm_unlearn

LLM Unlearning
MIT License

Is this the right way to compute KL divergence? #4

Open himalalps opened 7 months ago

himalalps commented 7 months ago

The code in utils.py that computes the KL divergence is as follows, but I think this may not actually be the KL divergence but rather the cross entropy.

https://github.com/kevinyaobytedance/llm_unlearn/blob/647f309519f91c29d87e62cf63d9a43759810040/utils.py#L199-L203

Why not directly use PyTorch KLDivLoss?
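For context, a minimal sketch of the relationship in question (pure Python, no PyTorch, with illustrative distributions `p` and `q` chosen here): KL(p‖q) equals the cross entropy H(p, q) minus the entropy H(p). When `p` comes from a frozen reference model, H(p) is a constant with respect to the trained model, so cross entropy and KL divergence have the same gradient even though their values differ.

```python
import math

def cross_entropy(p, q):
    # H(p, q) = -sum_i p_i * log q_i
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def entropy(p):
    # H(p) = -sum_i p_i * log p_i
    return -sum(pi * math.log(pi) for pi in p)

def kl_divergence(p, q):
    # KL(p || q) = H(p, q) - H(p); differs from cross entropy
    # only by the (q-independent) entropy of p
    return cross_entropy(p, q) - entropy(p)

# Illustrative distributions (not from the repo's code)
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

print(cross_entropy(p, q))
print(kl_divergence(p, q))
```

Note that `torch.nn.KLDivLoss` (and `torch.nn.functional.kl_div`) expects the *input* as log-probabilities and the *target* as probabilities by default, which is a common source of confusion when swapping it in.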