idiap / importance-sampling

Code for experiments regarding importance sampling for training neural networks

Some confusion about the fast grads calculation when converting to PyTorch. #36

Open ShunLu91 opened 2 years ago

ShunLu91 commented 2 years ago

Hello,

Thank you for your great efforts. After reading your paper and code, I found it to be a nice, solid piece of work, and I really enjoyed it.

To use this method in my own model training, I am trying to implement it in the PyTorch framework. I noticed that you use the following code to calculate the gradient norm in fast mode:

if self.fast:
    grads = K.sqrt(sum([
        self._sum_per_sample(K.square(g))
        for g in K.gradients(losses, self.parameter_list)
    ]))

As far as I understand, the list comprehension [self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)] already squares the gradients and sums them per sample. I am confused about why you do not directly apply K.sqrt() to obtain the gradient norm of each sample, but instead wrap the list in another sum() before taking K.sqrt().

Besides, I compared the results of sum([self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)]) and [self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)], and found that they were equal, which surprised me. However, if I remove the sum() inside K.sqrt(), a data type error is raised. Does this sum() therefore only convert the data type rather than perform a summation?
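
For reference, here is a minimal PyTorch sketch of my current understanding of this computation (not code from this repository; the model, data, and loss below are just placeholders): the inner step sums each parameter tensor's squared gradients per sample, the outer sum() adds those values across parameter tensors, and the square root then gives the per-sample gradient norm.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                      # placeholder model
criterion = nn.CrossEntropyLoss(reduction="none")
x = torch.randn(8, 10)                        # placeholder batch
y = torch.randint(0, 2, (8,))

params = [p for p in model.parameters() if p.requires_grad]
losses = criterion(model(x), y)               # one loss value per sample

grad_norms = []
for i in range(losses.shape[0]):
    # gradients of the i-th sample's loss w.r.t. every parameter tensor
    grads = torch.autograd.grad(losses[i], params, retain_graph=True)
    # square and sum within each tensor, add across tensors, then sqrt
    per_sample_sq = sum(g.pow(2).sum() for g in grads)
    grad_norms.append(per_sample_sq.sqrt())

grad_norms = torch.stack(grad_norms)          # shape: (batch_size,)
print(grad_norms)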

I look forward to your reply, and I will share my PyTorch implementation once it is ready.

Best, Shun Lu

3DJakob commented 1 year ago

@ShunLu91 Did you eventually release your PyTorch code? I would be very interested in looking at it.

ShunLu91 commented 1 year ago

@3DJakob Sorry for the late reply; we have just released our code. Our method, PA&DA, explicitly minimizes the gradient variance of supernet training by jointly optimizing the sampling distributions of PAths and DAta. For the essential part of calculating the gradient of each data sample, we ultimately took the reference from RHO-Loss and use it here. You are welcome to try our code and discuss it with us.
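
As a rough illustration of the data-sampling side (a simplified sketch, not our actual PA&DA implementation; the score values below are placeholders), one can draw training samples in proportion to a per-sample score such as the gradient norm and reweight the drawn losses to keep the estimate unbiased:

import torch

scores = torch.rand(1000).clamp_min(1e-8)    # placeholder per-sample scores, e.g. gradient norms
probs = scores / scores.sum()                # sampling distribution proportional to the scores

batch_size = 32
idx = torch.multinomial(probs, batch_size, replacement=True)   # sampled example indices

# importance weights 1 / (N * p_i) keep the weighted loss an unbiased estimate
weights = 1.0 / (len(scores) * probs[idx])

# the weights would then multiply the per-sample losses of the drawn examples,
# e.g. loss = (weights * criterion(model(x[idx]), y[idx])).mean()
print(idx[:5], weights[:5])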

3DJakob commented 1 year ago

@ShunLu91 Thank you for your response, I will take a look at it 👍🏻