zjunlp / EasyEdit

[ACL 2024] An Easy-to-use Knowledge Editing Framework for LLMs.
https://zjunlp.github.io/project/KnowEdit
MIT License
1.74k stars 210 forks

line 118 in ft_main.py #249

Closed SH9959 closed 3 months ago

SH9959 commented 4 months ago

Line 118 of ft_main.py:

loss_mask = target_ids != tok.unk_token_id

This line does not seem to handle the case where tok.unk_token_id is None. If tok.unk_token_id is None, then line 201,

loss = -(torch.gather(probs, 1, target_ids) * loss_mask).sum(1) / loss_mask.sum(1)

raises an error at .sum().

SH9959 commented 4 months ago

I'm not sure whether it would be better to modify it as follows. 😊

  if tok.unk_token_id is None:
      tok.unk_token_id = tok.pad_token_id
  loss_mask = target_ids != tok.unk_token_id
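
A variant of the same idea that avoids overwriting tok.unk_token_id might look like the sketch below. It is only an illustration, not the code that was merged: resolve_ignore_id and build_loss_mask are hypothetical helper names, the tokenizer is a SimpleNamespace stand-in, and plain nested lists replace the tensors so the snippet runs without torch. With real tensors, the mask would simply be target_ids != ignore_id.

```python
from types import SimpleNamespace
from typing import List


def resolve_ignore_id(tok) -> int:
    """Pick the token id to exclude from the loss without mutating `tok`.

    Hypothetical helper: prefers unk_token_id, falls back to pad_token_id.
    """
    for candidate in (tok.unk_token_id, tok.pad_token_id):
        if candidate is not None:
            return candidate
    # Fail loudly instead of letting a None id reach the mask computation.
    raise ValueError("tokenizer defines neither unk_token_id nor pad_token_id")


def build_loss_mask(target_ids: List[List[int]], ignore_id: int) -> List[List[bool]]:
    """List-based stand-in for the tensor expression `target_ids != ignore_id`."""
    return [[tid != ignore_id for tid in row] for row in target_ids]


# Assumed example values: a tokenizer with no unknown token and a pad id of 0.
tok = SimpleNamespace(unk_token_id=None, pad_token_id=0)
ignore_id = resolve_ignore_id(tok)
mask = build_loss_mask([[5, 7, 0], [9, 0, 0]], ignore_id)
print(ignore_id, mask)
```

Compared with assigning to tok.unk_token_id, this leaves the tokenizer untouched for any later code that checks whether an unknown token exists, and it raises a clear error when neither id is available.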

XeeKee commented 3 months ago

Thank you very much for your advice; we have updated the code. Wishing you a pleasant life!