gbaydin / hypergradient-descent


Validation Loss for HD #2

Open nimamox opened 6 years ago

nimamox commented 6 years ago

Can we use validation loss instead of training loss to perform hypergradient descent in your implementation?

akaniklaus commented 5 years ago

@nimamox Yes, it is possible. From the paper: "Note that we use the training objective, as opposed to the validation objective as in Maclaurin et al. (2015), for computing hypergradients. Modifications of HD computing gradients for both training and validation sets at each iteration and using the validation gradient only for updating α are possible, but not presented in this paper."
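For concreteness, a rough sketch of that modification in plain PyTorch (not using the optimizers from this repo) could look like the following; `train_loss`, `val_loss`, `params`, and the constants are illustrative placeholders, and the learning-rate update follows the HD rule from the paper but is evaluated on validation gradients:

```python
import torch

# Illustrative only: manual SGD where the learning rate is adapted by hypergradient
# descent on the validation loss, while the parameter update still uses the training loss.
lr, hypergrad_lr = 0.01, 1e-8   # placeholder values
val_grad_prev = None

for step in range(1000):
    # Gradient of the validation loss, used only to update the learning rate.
    val_grad = torch.autograd.grad(val_loss(params), params)[0].flatten()
    if val_grad_prev is not None:
        # HD rule: lr <- lr + hypergrad_lr * (current . previous), here on validation gradients.
        lr += hypergrad_lr * torch.dot(val_grad, val_grad_prev).item()
    val_grad_prev = val_grad

    # Gradient of the training loss, used to update the parameters as usual.
    train_grad = torch.autograd.grad(train_loss(params), params)[0]
    with torch.no_grad():
        params -= lr * train_grad
```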

@gbaydin It would be great if you can provide an example implementation of this in this repo though.

akaniklaus commented 5 years ago

@nimamox I am working on an implementation that checkpoints the model and optimizer (trained without hypergradient descent), then continues training on the validation loss with a hypergradient-descent-enabled optimizer, and finally reverts to the checkpoint while keeping the learning rate that the HD optimizer updated during that validation phase. I am curious whether there is a nicer way of implementing this, and I would be glad to hear @gbaydin's suggested implementation for using the validation loss.
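A rough sketch of that scheme (names and values are illustrative; `SGDHD` is assumed to be importable from the `hypergrad` package in this repo, and the learning rate is assumed to be accessible via `param_groups[0]['lr']`):

```python
import copy
import torch
from hypergrad import SGDHD  # HD-enabled optimizer from this repo (assumed import path)

lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr)              # plain optimizer for training
hd_optimizer = SGDHD(model.parameters(), lr=lr, hypergrad_lr=1e-8)  # used only on the validation loss

for step in range(1000):
    # 1. Checkpoint the model before the validation step.
    model_state = copy.deepcopy(model.state_dict())

    # 2. One step on the validation loss with the HD optimizer; this adapts its learning rate.
    hd_optimizer.param_groups[0]['lr'] = lr
    hd_optimizer.zero_grad()
    val_loss(model).backward()
    hd_optimizer.step()
    lr = hd_optimizer.param_groups[0]['lr']  # learning rate updated from the validation hypergradient

    # 3. Revert the parameters to the checkpoint, keeping only the new learning rate.
    model.load_state_dict(model_state)
    optimizer.param_groups[0]['lr'] = lr

    # 4. Normal training step with the updated learning rate.
    optimizer.zero_grad()
    train_loss(model).backward()
    optimizer.step()
```

Whether the HD optimizer's internal state (e.g. its stored previous gradient) should also be reverted after each validation step is a design choice this sketch glosses over.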

gbaydin commented 5 years ago

@nimamox, @akaniklaus sorry for the delayed reply. It is possible to use the validation loss in the hypergradient update step. The code in this repo doesn't support that, because supporting it would break compatibility with the PyTorch optim API. Since our aim with this repo was to provide code that can be used as a drop-in replacement for the PyTorch optim modules, I think it's unlikely that we will add this.

The needed code change would be simple. For example, in the SGD code, one would just need to change the following line https://github.com/gbaydin/hypergradient-descent/blob/master/hypergrad/sgd_hd.py#L123 so that it uses the gradients of the validation loss with respect to the model parameters at the current and the previous iteration (the grad and grad_prev variables in the code).
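As a rough illustration (assuming that line computes the hypergradient as the dot product of the current and previous gradients, as in the paper), the validation-loss variant could look like this; val_grad and val_grad_prev are hypothetical tensors that the caller would have to compute on the validation loss and hand to the optimizer, since the standard optim API only sees the gradients from the last backward pass:

```python
# Training-loss hypergradient (current behaviour, per the paper):
#     h = torch.dot(grad, grad_prev)
# Hypothetical validation-loss variant:
h = torch.dot(val_grad.view(-1), val_grad_prev.view(-1))
# The learning-rate update that consumes h would stay unchanged.
```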