songmzhang / DSKD

Repo for Paper "Dual-Space Knowledge Distillation for Large Language Models".
29 stars 3 forks source link

关于 AKL 的计算 #1

Closed wutaiqiang closed 2 months ago

wutaiqiang commented 2 months ago

你好!非常棒的工作!

注意到,AKL 的计算是:

https://github.com/songmzhang/DSKD/blob/de4c3b3dda974139b813acb0c4eeb457b59a70ee/code/criterions/various_divergence.py#L142

不过,这么来算可能有点问题。

AKL 的 mask 选取原则是:

image

也就是,使用 选取最少的元素 使得元素和大于某个阈值 mu

在贵代码中,应该是假定了 alpha=1-mu, 假定 sorted_teacher_probs 是 0.1 0.2 0.3 0.4,那么cum_teacher_probs是 0.1 0.3 0.6 1.0。那么有

mu=0.0 时,alpha=1.0,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.05 时,alpha=0.95,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.3 时,alpha=0.7,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.4 时,alpha=0.6,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),不符合 mu=0.5 时,alpha=0.5,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.6 时,alpha=0.4,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.7 时,alpha=0.3,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),不符合 mu=0.8 时,alpha=0.2,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),符合 mu=1.0 时,alpha=0.0,那tail_mask 是 0 0 0 0,也就是 || 0.1 0.2 0.3 0.4 (尾||头),符合

假设把https://github.com/songmzhang/DSKD/blob/de4c3b3dda974139b813acb0c4eeb457b59a70ee/code/criterions/various_divergence.py#L167 的 lt 改为 le,那么:

mu=0.05 时,alpha=0.95,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 ( 尾||头),符合 mu=0.3 时,alpha=0.7,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.4 时,alpha=0.6,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.5 时,alpha=0.5,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.6 时,alpha=0.4,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.7 时,alpha=0.3,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.8 时,alpha=0.2,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),符合 mu=1.0 时,alpha=0.0,那tail_mask 是 0 0 0 0,也就是 || 0.1 0.2 0.3 0.4 (尾||头),符合

逻辑上来说,使用 选取最少的元素 使得元素和>=某个阈值 mu 等价于 使用 选取最多的元素 使得元素和<=1-mu, 而不是 使用 选取最多的元素 使得元素和<1-mu

所以正确的写法应该是:

tail_mask = cum_teacher_probs.le(alpha).float()

实际使用中, mu=0.5,也就是 alpha=0.5,那么当logit 是 0.1 0.2 0.2 0.2 0.3 时,就会出现类似上述的问题,模型可能错误划分为 0.1 0.2 || 0.2 0.2 0.3,造成比例的失真。

songmzhang commented 2 months ago

你好!非常棒的工作!

注意到,AKL 的计算是:

https://github.com/songmzhang/DSKD/blob/de4c3b3dda974139b813acb0c4eeb457b59a70ee/code/criterions/various_divergence.py#L142

不过,这么来算可能有点问题。

AKL 的 mask 选取原则是: image

也就是,使用 选取最少的元素 使得元素和大于某个阈值 mu

在贵代码中,应该是假定了 alpha=1-mu, 假定 sorted_teacher_probs 是 0.1 0.2 0.3 0.4,那么cum_teacher_probs是 0.1 0.3 0.6 1.0。那么有

mu=0.0 时,alpha=1.0,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.05 时,alpha=0.95,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.3 时,alpha=0.7,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.4 时,alpha=0.6,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),不符合 mu=0.5 时,alpha=0.5,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.6 时,alpha=0.4,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.7 时,alpha=0.3,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),不符合 mu=0.8 时,alpha=0.2,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),符合 mu=1.0 时,alpha=0.0,那tail_mask 是 0 0 0 0,也就是 || 0.1 0.2 0.3 0.4 (尾||头),符合

假设把

https://github.com/songmzhang/DSKD/blob/de4c3b3dda974139b813acb0c4eeb457b59a70ee/code/criterions/various_divergence.py#L167

的 lt 改为 le,那么: mu=0.05 时,alpha=0.95,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 ( 尾||头),符合 mu=0.3 时,alpha=0.7,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.4 时,alpha=0.6,那tail_mask 是 1 1 1 0,也就是 0.1 0.2 0.3 || 0.4 (尾||头),符合 mu=0.5 时,alpha=0.5,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.6 时,alpha=0.4,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.7 时,alpha=0.3,那tail_mask 是 1 1 0 0,也就是 0.1 0.2 || 0.3 0.4 (尾||头),符合 mu=0.8 时,alpha=0.2,那tail_mask 是 1 0 0 0,也就是 0.1 || 0.2 0.3 0.4 (尾||头),符合 mu=1.0 时,alpha=0.0,那tail_mask 是 0 0 0 0,也就是 || 0.1 0.2 0.3 0.4 (尾||头),符合

逻辑上来说,使用 选取最少的元素 使得元素和>=某个阈值 mu 等价于 使用 选取最多的元素 使得元素和<=1-mu, 而不是 使用 选取最多的元素 使得元素和<1-mu

所以正确的写法应该是:

tail_mask = cum_teacher_probs.le(alpha).float()

实际使用中, mu=0.5,也就是 alpha=0.5,那么当logit 是 0.1 0.2 0.2 0.2 0.3 时,就会出现类似上述的问题,模型可能错误划分为 0.1 0.2 || 0.2 0.2 0.3,造成比例的失真。

非常感谢您的提醒!我会尽快更正这个错误并更新相应的实验结果!