cschenxiang / DRSformer

Learning A Sparse Transformer Network for Effective Image Deraining (CVPR 2023)
265 stars 14 forks source link

关于topk可微分问题 #24

Open lijun2005 opened 6 months ago

lijun2005 commented 6 months ago

作者您好,请问文章在使用topk生成mask进而计算注意力时,有考虑过此操作的可微性? 这种离散的操作我认为是不可微分的,那么神经网络在训练时是如何处理这个问题呢?换句话说,topk是如何奏效的?因为不可微分也许会干扰神经网络的训练,但是文章采用了topk效果非常好。 对这个问题我很困惑,还请作者在百忙之间解答一下我的疑惑,谢谢!

maoxiaowei97 commented 3 months ago

兄弟,argmax不可微,topk是可微的,取出来的数据在计算图中可以保留前后序计算

lijun2005 commented 3 months ago

额,按你这么说现在所有在做可微topk的都不用做了,只需要torch.topk了😂😂

---原始邮件--- 发件人: @.> 发送时间: 2024年7月19日(周五) 晚上10:04 收件人: @.>; 抄送: @.**@.>; 主题: Re: [cschenxiang/DRSformer] 关于topk可微分问题 (Issue #24)

兄弟,argmax不可微,topk是可微的,取出来的数据在计算图中可以保留前后序计算

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>

lijun2005 commented 3 months ago

并且作者的做法是根据topk得出的索引值,根据索引值进行mask,这个索引值还能可微?

maoxiaowei97 commented 3 months ago

并且作者的做法是根据topk得出的索引值,根据索引值进行mask,这个索引值还能可微?

抱歉之前我没仔细看这部分代码。

我理解是不是这里的 torch.where 操作是可微的,输出 attn1 的梯度可以通过 attn 进行传播。尽管掩码 mask1 本身不可微,但它的作用是选择性地保留 attn 中的值,这种选择不会中断计算图的构建。 attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))

掩码是根据attn的值确定的,且通过不可微操作生成,这些操作不会记录在计算图中,其本身不参与反向传播的梯度计算。我理解掩码和输入张量都可以被视为叶子节点。

maoxiaowei97 commented 3 months ago

额,按你这么说现在所有在做可微topk的都不用做了,只需要torch.topk了😂😂 ---原始邮件--- 发件人: @.> 发送时间: 2024年7月19日(周五) 晚上10:04 收件人: @.>; 抄送: @.**@.>; 主题: Re: [cschenxiang/DRSformer] 关于topk可微分问题 (Issue #24) 兄弟,argmax不可微,topk是可微的,取出来的数据在计算图中可以保留前后序计算 — Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>

你说得没错,torch.topk, torch.argmax取索引的操作是不可微分,如果取得的结果直接用于求loss,比如用文本摘要生成的各个词(greedy-argmax, beam search-topk)组成的句子去计算一些评价指标,如BLEU, ROUGE,这样计算得到的评价指标是不可以求微分的。但作者这里topk生成的掩码应该是不参与反向传播的梯度计算

lijun2005 commented 3 months ago

嗯对,我的问题就是比如在网络训练初期模型还不太稳定的时候,这种计算出的掩码是否会有点问题。很感谢跟您一起讨论,这篇文章的做法来源于一篇ECCV的文章,我相信本文作者也就是拿过来试了一试,并没有对里面的可微性进行分析。那篇ECCV的文章的动机是利用topk获得掩码后得到的注意力矩阵为稀疏矩阵从而加速Transformer计算。

---原始邮件--- 发件人: @.> 发送时间: 2024年7月20日(周六) 中午11:10 收件人: @.>; 抄送: @.**@.>; 主题: Re: [cschenxiang/DRSformer] 关于topk可微分问题 (Issue #24)

并且作者的做法是根据topk得出的索引值,根据索引值进行mask,这个索引值还能可微?

抱歉之前我没仔细看这部分代码。

我理解是不是这里的 torch.where 操作是可微的,输出 attn1 的梯度可以通过 attn 进行传播。尽管掩码 mask1 本身不可微,但它的作用是选择性地保留 attn 中的值,这种选择不会中断计算图的构建。 attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))

掩码是根据attn的值确定的,且通过不可微操作生成,这些操作不会记录在计算图中,其本身不参与反向传播的梯度计算。我理解掩码和输入张量都可以被视为叶子节点。

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>