vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by PyTorch Lightning
MIT License

Can I use "torch.autograd.grad" in "training_step" function? #197

Closed (chenshuang-zhang closed this issue 2 years ago)

chenshuang-zhang commented 2 years ago

Hello!

I want to compute the gradient of the loss with respect to the model input for every batch during training. The gradients are then processed by some other functions, and the result is added to the loss. In plain PyTorch I do this with the torch.autograd.grad function, like this:

    model_input.requires_grad = True
    model.zero_grad()
    model.eval()
    grads = torch.autograd.grad(loss, model_input, grad_outputs=None, only_inputs=True, retain_graph=False)[0]
    model.train()

Can I add this code directly into the training_step function in PyTorch Lightning?

My concerns are:

  1. Will this gradient calculation (the torch.autograd.grad call) affect the accuracy of model training if I add it to the training_step function?
  2. Do I need to call model.zero_grad(), model.eval() and model.train() when calculating the gradients with respect to the input?

Thank you very much!

vturrisi commented 2 years ago

Hey, I'm not entirely sure how torch.autograd.grad(...) works, but I would say that you can't backpropagate through the model twice without retain_graph=True. Also, you don't need zero_grad(), as there are no gradients stored in the model at this point. My suggestion would be to try it and see what happens.
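A quick way to check both points is a small standalone script. The sketch below uses plain PyTorch rather than solo-learn or Lightning; the toy model, the tensor shapes, and the squared-output loss are assumptions made only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model and input (illustrative assumptions, not part of solo-learn).
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(2, 4, requires_grad=True)

# Case 1: the graph is freed after the first pass, so a second pass fails.
loss = model(x).pow(2).mean()
torch.autograd.grad(loss, x)  # retain_graph defaults to False here
try:
    loss.backward()
    second_pass_failed = False
except RuntimeError:
    second_pass_failed = True

# Case 2: retain_graph=True keeps the graph alive for a later backward().
loss = model(x).pow(2).mean()
(input_grad,) = torch.autograd.grad(loss, x, retain_graph=True)

# torch.autograd.grad returns the gradients instead of accumulating them,
# so the parameters' .grad fields are still empty at this point.
params_untouched = all(p.grad is None for p in model.parameters())

loss.backward()  # works; only now are the parameters' .grad fields filled
```

Note that with retain_graph=True alone, input_grad is a constant as far as autograd is concerned, so a loss term built from it would not contribute gradients to the model parameters.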

Also, maybe this is interesting for you https://github.com/pytorch/functorch
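If the term computed from the input gradients is supposed to influence training, the gradients themselves have to stay differentiable, which is what create_graph=True is for. Below is a minimal sketch of that pattern, again in plain PyTorch. The helper name loss_with_input_grad_penalty, the weight value, and the squared-norm penalty are hypothetical stand-ins for the "other functions" mentioned in the question, not anything solo-learn provides.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_with_input_grad_penalty(model, x, y, weight=0.1):
    """Cross-entropy plus a penalty on the gradient of the loss w.r.t. the input.

    Hypothetical helper: `weight` and the squared-norm penalty are
    illustrative assumptions.
    """
    # Work on a fresh leaf tensor that tracks gradients.
    x = x.detach().clone().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)

    # create_graph=True keeps the graph and makes `grads` differentiable,
    # so a term built from it can be backpropagated with the main loss.
    (grads,) = torch.autograd.grad(loss, x, create_graph=True)

    penalty = grads.pow(2).sum(dim=1).mean()
    return loss + weight * penalty, grads


torch.manual_seed(0)
# Toy classifier and batch (assumptions for illustration only).
model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 3))
x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))

total, grads = loss_with_input_grad_penalty(model, x, y)
# No zero_grad() was needed: nothing has been written to .grad so far.
no_param_grads_yet = all(p.grad is None for p in model.parameters())
total.backward()  # a single backward pass through the combined loss
```

In a LightningModule with automatic optimization, the body of the helper could presumably live inside training_step, with the combined loss returned so that Lightning performs the single backward pass; that part is untested here.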

chenshuang-zhang commented 2 years ago

@vturrisi Thank you very much for your help! I will try and see what happens!