I am implementing a multi-step optimization problem in which two auxiliary models (a visual_encoder ResNet and a coefficient_vector) are used to calculate a weighted training loss. Backpropagating this loss updates my main model's weights. Then, using a validation loss, I'd like to update the visual_encoder and coefficient_vector parameters with the higher library.
The following function returns the gradients for those two models so that I can subsequently update them in another function.
However, when calling `logits = fmodel(input)` there is a very large memory allocation that is never freed, even after the function has returned the gradients. Is there anything I am doing wrong here? Which reference am I missing that keeps the memory alive? Any hint is highly appreciated, and my apologies if this is not the right place to ask.
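As background on the "never freed" symptom: memory stays allocated as long as some object still holds a strong reference to it (in PyTorch, that is typically the autograd graph hanging off a tensor that is still reachable). A minimal, framework-free sketch of that mechanism, using only the stdlib `weakref` module (all names here are illustrative, not part of higher):

```python
import gc
import weakref

class BigBuffer:
    """Stand-in for a large tensor / autograd graph node."""
    def __init__(self):
        self.data = [0.0] * 1_000_000

def make_and_drop():
    buf = BigBuffer()
    ref = weakref.ref(buf)   # weak reference: does not keep buf alive
    leaked = buf             # simulate a lingering strong reference
    del buf
    gc.collect()
    alive_with_ref = ref() is not None  # still alive: `leaked` points at it
    del leaked
    gc.collect()
    alive_after = ref() is not None     # last reference gone, memory freed
    return alive_with_ref, alive_after

print(make_and_drop())  # (True, False)
```

The same reasoning applies here: as long as anything reachable (a returned tensor, a closure, the functional model itself) still references the inner-loop graph, the allocations behind it cannot be collected.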
```python
import torch
import torch.nn.functional as F
import higher

def compute_meta_gradients(model, optimizer, input, target, input_val, target_val,
                           coefficient_vector, visual_encoder):
    with higher.innerloop_ctx(model, optimizer) as (fmodel, foptimizer):
        logits = fmodel(input)  # heavy memory allocation here which is never freed
        # returns a weight for each training sample (input)
        weights = calc_instance_weights(input, target, input_val, target_val,
                                        logits, coefficient_vector, visual_encoder)
        weighted_training_loss = torch.mean(
            weights * F.cross_entropy(logits, target, reduction='none'))
        foptimizer.step(weighted_training_loss)  # update fmodel main model weights

        logits_val = fmodel(input_val)
        meta_val_loss = F.cross_entropy(logits_val, target_val)
        # gradients w.r.t. the coefficient vector
        coeff_vector_gradients = torch.autograd.grad(
            meta_val_loss, coefficient_vector, retain_graph=True)
        coeff_vector_gradients = coeff_vector_gradients[0].detach()
        # gradients w.r.t. the ResNet parameters
        visual_encoder_gradients = torch.autograd.grad(
            meta_val_loss, visual_encoder.parameters())
        visual_encoder_gradients = tuple(g.detach() for g in visual_encoder_gradients)
    return visual_encoder_gradients, coeff_vector_gradients
```
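For context, the outer step that consumes the returned (detached) gradients is just a plain gradient-descent update on the two auxiliary models. A framework-agnostic sketch, with hypothetical names and plain floats standing in for parameter tensors:

```python
def outer_update(params, grads, lr=0.1):
    """Plain SGD step: p <- p - lr * g, applied elementwise."""
    return [p - lr * g for p, g in zip(params, grads)]

# toy stand-ins for coefficient_vector and its meta-gradients
coefficient_vector = [1.0, -2.0, 0.5]
coeff_vector_gradients = [0.5, -0.5, 1.0]

coefficient_vector = outer_update(coefficient_vector, coeff_vector_gradients)
print(coefficient_vector)
```

Because the update only needs the detached gradient values, nothing in this outer step should require the inner-loop graph to stay alive.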
Thanks so much for providing this library!
Thanks a lot!