kevinyaobytedance / llm_unlearn

LLM Unlearning
MIT License

Is this the right way to compute KL divergence? #4

Open himalalps opened 3 months ago


The code in utils.py that computes the KL divergence is as follows, but I think this may not actually be the KL divergence but rather the cross entropy.

https://github.com/kevinyaobytedance/llm_unlearn/blob/647f309519f91c29d87e62cf63d9a43759810040/utils.py#L199-L203

Why not directly use PyTorch's KLDivLoss?
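For context, here is a minimal sketch of the distinction being raised. This is not the repo's actual code; `logits_p`/`logits_q` are hypothetical stand-ins for the reference (pretrained) model's and current model's outputs. It shows that cross entropy and KL divergence differ by the entropy of the target distribution, and how `F.kl_div` could be used directly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: "p" plays the role of the reference
# (pretrained) model's token distribution, "q" the current model's.
logits_p = torch.randn(4, 10)  # reference model logits
logits_q = torch.randn(4, 10)  # current model logits

prob_p = F.softmax(logits_p, dim=-1)
log_q = F.log_softmax(logits_q, dim=-1)

# Cross entropy: H(p, q) = -sum_x p(x) * log q(x)
cross_entropy = -(prob_p * log_q).sum(-1).mean()

# KL divergence: KL(p || q) = sum_x p(x) * (log p(x) - log q(x))
#                           = H(p, q) - H(p)
entropy_p = -(prob_p * prob_p.log()).sum(-1).mean()
kl_manual = cross_entropy - entropy_p

# PyTorch's built-in: F.kl_div expects log-probabilities as the
# input (q) and, by default, probabilities as the target (p).
# "batchmean" divides the summed divergence by the batch size.
kl_builtin = F.kl_div(log_q, prob_p, reduction="batchmean")

print(torch.allclose(kl_manual, kl_builtin, atol=1e-6))  # True
```

Note that since the entropy H(p) of the (detached) reference distribution does not depend on the current model's parameters, minimizing the cross entropy H(p, q) and minimizing KL(p || q) yield the same gradients with respect to q, which may explain why the two are sometimes used interchangeably as a training loss.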