SineZHAN / deepALplus

This is a toolbox for Deep Active Learning, an extension of the earlier DeepAL toolbox (https://github.com/ej0cl6/deep-active-learning).

Issue in nets_waal.py file #5


frankiz22 commented 1 year ago

Dear Repository owners,

I would like to use deepALplus for experiments with Deep Active Learning, but I have trouble understanding nets_waal.py. At line 81, you recompute the features for the labeled and unlabeled data inside a `with torch.no_grad()` block, then compute the gradient penalty and add it to the loss. Because the features are computed under `torch.no_grad()`, the gradient penalty contributes nothing to the updates of the feature extractor's weights. And since at line 64 you set `requires_grad=False` for the discriminator, the discriminator's weights will not be updated either.

I would like to know why you recompute the features inside the `with torch.no_grad()` block, since this seems to make the gradient penalty have no impact on the weight updates of the entire model.
Thank you.
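
To make the concern concrete, here is a minimal, self-contained sketch (toy modules and names of my own, not the repo's code) showing that a penalty built on features computed under `torch.no_grad()` carries no autograd history, so it cannot move the feature extractor's weights:

```python
import torch
import torch.nn as nn

# Toy stand-in for the WAAL feature extractor (hypothetical, for illustration).
feature_net = nn.Linear(4, 2)
x = torch.randn(8, 4)

# Features computed inside torch.no_grad() carry no autograd history.
with torch.no_grad():
    z = feature_net(x)
print(z.requires_grad)  # False

# Any penalty built on z is a constant w.r.t. feature_net's parameters:
# it has no grad_fn, so adding it to a training loss contributes zero
# gradient to the feature extractor (calling penalty.backward() on its
# own would even raise an error).
penalty = (z.norm(dim=1) - 1).pow(2).mean()
print(penalty.requires_grad)  # False
```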

SineZHAN commented 1 year ago

We only update the discriminator in the second step; see the code after line 97.
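
For readers following along, here is a rough, self-contained sketch of the alternating two-step scheme this reply seems to describe (all names, shapes, and coefficients are hypothetical, not the repo's code): step 1 updates the feature extractor and classifier with the discriminator frozen, and step 2 updates only the discriminator, which is where the gradient penalty is meant to act.

```python
import torch
import torch.nn as nn

# Hypothetical toy networks standing in for the WAAL components.
feature_net = nn.Linear(4, 2)
classifier = nn.Linear(2, 3)
discriminator = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))
opt_fea = torch.optim.SGD(
    list(feature_net.parameters()) + list(classifier.parameters()), lr=0.01)
opt_dis = torch.optim.SGD(discriminator.parameters(), lr=0.01)

x_lb, y_lb = torch.randn(8, 4), torch.randint(0, 3, (8,))
x_unlb = torch.randn(8, 4)

def gradient_penalty(dis, real, fake):
    # Standard WGAN-GP penalty on interpolates between two feature batches.
    eps = torch.rand(real.size(0), 1)
    mix = eps * real + (1 - eps) * fake
    if not mix.requires_grad:          # leaf when inputs came from no_grad
        mix.requires_grad_(True)
    out = dis(mix)
    grads, = torch.autograd.grad(out, mix,
                                 grad_outputs=torch.ones_like(out),
                                 create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Step 1: update feature extractor + classifier; discriminator frozen.
for p in discriminator.parameters():
    p.requires_grad = False
opt_fea.zero_grad()
task_loss = nn.functional.cross_entropy(classifier(feature_net(x_lb)), y_lb)
adv = (discriminator(feature_net(x_lb)).mean()
       - discriminator(feature_net(x_unlb)).mean())
(task_loss + 0.1 * adv).backward()
opt_fea.step()

# Step 2: update only the discriminator. Features are recomputed under
# no_grad and treated as fixed inputs, so the gradient penalty here only
# shapes the discriminator's weights.
for p in discriminator.parameters():
    p.requires_grad = True
with torch.no_grad():
    lb_z, unlb_z = feature_net(x_lb), feature_net(x_unlb)
opt_dis.zero_grad()
gp = gradient_penalty(discriminator, lb_z, unlb_z)
dis_loss = discriminator(unlb_z).mean() - discriminator(lb_z).mean() + 5.0 * gp
dis_loss.backward()
opt_dis.step()
```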

BenjaminMidtvedt commented 8 months ago

@SineZHAN I am confused about the same thing.

You calculate the gradient penalty twice, at lines 86 and 115. The second calculation updates the discriminator. However, the first calculation has no gradients anywhere, so I don't see how it could result in any weight updates. It would make sense to me if `lb_z` and `unlb_z` at lines 83 and 84 were computed without `torch.no_grad()`, so that the feature network would be updated by the gradient penalty.
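
Continuing the toy setup from the sketch above, the change being suggested would look roughly like this (again hypothetical code, not a patch against the repo):

```python
# Recompute the features WITHOUT torch.no_grad(), so the interpolates in
# the penalty stay connected to feature_net's autograd graph.
lb_z = feature_net(x_lb)
unlb_z = feature_net(x_unlb)

gp = gradient_penalty(discriminator, lb_z, unlb_z)
# Included in the step-1 loss, gp now backpropagates through feature_net,
# whereas under torch.no_grad() it contributed nothing to its weights.
```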