hustvl / WeakTr

WeakTr: Exploring Plain Vision Transformer for Weakly-supervised Semantic Segmentation

about gradient clip #9

Closed Jingfeng-Tang closed 1 year ago

Jingfeng-Tang commented 1 year ago

Thanks for your brilliant work in WSSS! I have a question about gradient clipping. https://github.com/hustvl/WeakTr/blob/main/OnlineRetraining/segm/model/decoder.py I think the code at lines 143-147 achieves the result (masked gradient patches) of formulas (12) and (13) in the paper. Lines 143-147:

local_mean = torch.maximum(local_mean, mean_loss)                         # max(lambda_i, lambda_global) from formula (12)
local_mean = torch.repeat_interleave(local_mean, b, dim=0)                # repeat across the batch dimension
local_mean = torch.repeat_interleave(local_mean, self.patch_size, dim=1)  # expand the patch grid to pixel height
local_mean = torch.repeat_interleave(local_mean, self.patch_size, dim=2)  # expand the patch grid to pixel width
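
For context, here is how I understand the broadcasting, as a minimal sketch (the shapes b=2, patch_size=2, and the 2x2 patch grid are my own assumptions for illustration, not the repository's actual values):

import torch

b, patch_size = 2, 2
# One threshold per patch on an assumed 2x2 patch grid (shape: 1 x 2 x 2).
local_mean = torch.tensor([[[0.3, 0.7],
                            [0.5, 0.9]]])
mean_loss = torch.tensor(0.6)

local_mean = torch.maximum(local_mean, mean_loss)                    # per-patch max with the other threshold
local_mean = torch.repeat_interleave(local_mean, b, dim=0)           # (2, 2, 2): copy across the batch
local_mean = torch.repeat_interleave(local_mean, patch_size, dim=1)  # (2, 4, 2): patch rows -> pixel rows
local_mean = torch.repeat_interleave(local_mean, patch_size, dim=2)  # (2, 4, 4): one threshold per pixel
print(local_mean.shape)  # torch.Size([2, 4, 4])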

In the paper, I did not find an explanation of the following code. Could you give one? Thanks. Lines 149-151:

clamp_loss = ori_loss - local_mean             # positive only where the per-pixel loss exceeds its threshold
clamp_loss = torch.clamp(clamp_loss, None, 0)  # clip that positive excess to zero
loss = clamp_loss + local_mean                 # add the threshold back to get the final loss
Yingyue-L commented 1 year ago

We appreciate your recognition of our work. To answer your question: first, mean_loss and local_mean correspond to $\lambda_i$ and $\lambda_{global}$ respectively, and lines 143-147 only compute the $\max(\lambda_i, \lambda_{global})$ term of formula (12). Lines 149-151 actually complete formulas (12) and (13):

  1. Line 149 computes a clipping mask clamp_loss: entries with relatively large gradients are greater than zero, and the rest are less than or equal to zero.
  2. Line 150 uses torch.clamp to perform gradient clipping in the regions where clamp_loss is greater than zero.
  3. Line 151 simply adds local_mean back to obtain the actual loss value.
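
To make this concrete, here is a minimal, self-contained sketch of the clamping trick (the tensor values are illustrative assumptions, not taken from the repository):

import torch

# Unreduced per-pixel losses and their (already broadcast) clipping thresholds.
ori_loss = torch.tensor([0.2, 1.5, 0.8, 3.0])
local_mean = torch.tensor([1.0, 1.0, 1.0, 1.0])

# Lines 149-151 of decoder.py, reproduced:
clamp_loss = ori_loss - local_mean             # > 0 only where the loss exceeds its threshold
clamp_loss = torch.clamp(clamp_loss, None, 0)  # clip that positive excess to zero
loss = clamp_loss + local_mean                 # add the threshold back

# The three lines clip each per-pixel loss at its threshold:
assert torch.allclose(loss, torch.minimum(ori_loss, local_mean))
print(loss)  # tensor([0.2000, 1.0000, 0.8000, 1.0000])

Since torch.clamp propagates zero gradient wherever its input is clipped, pixels whose loss exceeds the threshold contribute no gradient through ori_loss, which is what realizes the gradient clipping of formulas (12) and (13).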
Jingfeng-Tang commented 1 year ago

Thanks.