hrzhang1123 / DTFD-MIL

MIT License
wrong gradient calculation code? #8

Open Treeboy2762 opened 1 year ago

Treeboy2762 commented 1 year ago

Hi @hrzhang1123,

With torch 1.12, the code raises:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of AsStridedBackward0, is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

at

    loss1 = ce_cri(gSlidePred, tslideLabel).mean()
    optimizer1.zero_grad()
    loss1.backward()  # <-- the RuntimeError is raised here
    torch.nn.utils.clip_grad_norm_(attCls.parameters(), 5)
    optimizer1.step()

This can be resolved by moving optimizer0.step() so that it is called right before optimizer1.step(), i.e. after loss1.backward(). This makes sense: optimizer0.step() updates the weights in place, so if it runs before the backward pass of loss1, that pass would see weights that differ from the ones used in the forward pass. Could you consider reviewing this?
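
To make the reordering concrete, here is a minimal sketch of it. The modules `tier1` and `tier2`, the tensor shapes, and the function name `train_step_reordered` are hypothetical stand-ins, not the repository's actual code; only `optimizer0`, `optimizer1`, `ce_cri`, `loss0`, and `loss1` mirror names used in this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two tiers (not the real DTFD-MIL modules).
tier1 = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
tier2 = nn.Linear(2, 2)
optimizer0 = torch.optim.SGD(tier1.parameters(), lr=0.1)
optimizer1 = torch.optim.SGD(tier2.parameters(), lr=0.1)
ce_cri = nn.CrossEntropyLoss(reduction="none")


def train_step_reordered(x, y):
    """One training step with both backward passes done before any step()."""
    feats = tier1(x)

    # First-tier loss: keep the graph, because loss1 backpropagates through it too.
    loss0 = ce_cri(feats, y).mean()
    optimizer0.zero_grad()
    loss0.backward(retain_graph=True)
    torch.nn.utils.clip_grad_norm_(tier1.parameters(), 5)

    # Second-tier loss: feats is not detached, so this backward pass also
    # reaches tier1 and adds to the gradients already stored there.
    loss1 = ce_cri(tier2(feats), y).mean()
    optimizer1.zero_grad()
    loss1.backward()
    torch.nn.utils.clip_grad_norm_(tier2.parameters(), 5)

    # In-place weight updates happen only after every backward pass is done.
    optimizer0.step()
    optimizer1.step()
    return loss0.item(), loss1.item()


x = torch.randn(5, 8)
y = torch.randint(0, 2, (5,))
print(train_step_reordered(x, y))
```

One thing to keep in mind with this ordering: the first-tier parameters are updated with gradients from both loss0 and loss1, which may or may not be what the authors intended.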

Tonyboy999 commented 1 year ago

Hi @Treeboy2762, I also hit this error with a newer version of PyTorch. It disappears when I use PyTorch 1.4.

Dootmaan commented 1 year ago

Thank you. I also solved this problem by switching to PyTorch 1.4.

Treeboy2762 commented 1 year ago

@Furyboyy Thanks! This temporarily solves the problem, but I am not sure it is an appropriate solution.

jasonyin20 commented 1 year ago

My solution is to adjust the position of optimizer0.step(); with that change the code runs.

weiaicunzai commented 6 months ago

As an alternative to adjusting the position of optimizer0.step(), calling detach() on the features appended to slide_pseudo_feat also works:

    slide_pseudo_feat.append(af_inst_feat.detach())
    slide_pseudo_feat.append(max_inst_feat.detach())
    slide_pseudo_feat.append(MaxMin_inst_feat.detach())
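
A minimal sketch of the detach() variant, under stated assumptions: `feature_net`, `sub_classifier`, `second_tier`, the mean pooling, and all shapes are hypothetical placeholders (the real model distills pseudo-bag features differently). Only `slide_pseudo_feat`, `ce_cri`, `optimizer0`, `optimizer1`, `gSlidePred`, `loss0`, and `loss1` follow names from this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; the real first tier is not a mean-pooled MLP.
feature_net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4))
sub_classifier = nn.Linear(4, 2)
second_tier = nn.Linear(4, 2)
optimizer0 = torch.optim.SGD(
    list(feature_net.parameters()) + list(sub_classifier.parameters()), lr=0.1
)
optimizer1 = torch.optim.SGD(second_tier.parameters(), lr=0.1)
ce_cri = nn.CrossEntropyLoss(reduction="none")


def train_step_detached(pseudo_bags, label):
    """One step where the second tier only sees detached first-tier features."""
    slide_pseudo_feat, sub_preds = [], []
    for bag in pseudo_bags:  # each bag: (n_instances, 8)
        bag_feat = feature_net(bag).mean(dim=0, keepdim=True)  # (1, 4)
        sub_preds.append(sub_classifier(bag_feat))
        # detach() cuts the autograd link back to the first tier.
        slide_pseudo_feat.append(bag_feat.detach())

    # First tier: the usual zero_grad / backward / step, no retain_graph needed.
    sub_labels = label.expand(len(pseudo_bags))
    loss0 = ce_cri(torch.cat(sub_preds, dim=0), sub_labels).mean()
    optimizer0.zero_grad()
    loss0.backward()
    optimizer0.step()

    # Second tier: its backward pass stops at the detached features, so the
    # in-place update made by optimizer0.step() above is never revisited.
    slide_feat = torch.cat(slide_pseudo_feat, dim=0)  # (num_bags, 4)
    gSlidePred = second_tier(slide_feat).mean(dim=0, keepdim=True)  # (1, 2)
    loss1 = ce_cri(gSlidePred, label).mean()
    optimizer1.zero_grad()
    loss1.backward()
    optimizer1.step()
    return loss0.item(), loss1.item()


bags = [torch.randn(6, 8) for _ in range(3)]
label = torch.tensor([1])
print(train_step_detached(bags, label))
```

Unlike moving optimizer0.step(), this variant means loss1 no longer contributes gradients to the first tier, so the two fixes are not equivalent in what they train.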

weiaicunzai commented 4 months ago

I haven't tested the performance with this change yet, but no errors are reported.