facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Memory not freed when moving out of scope? #105

Open jessicamecht opened 3 years ago

jessicamecht commented 3 years ago

Hi, thanks so much for providing this library!

I am implementing a multi-step optimization problem in which two models (a visual_encoder resnet and a coefficient_vector) are used to calculate a weighted training loss. Backpropagating this loss updates my main model weights. Then, using the validation loss, I'd like to update the visual_encoder resnet and coefficient_vector parameters with higher.

The following function returns the gradients for those two models so that I can subsequently update them in another function.

However, calling logits = fmodel(input) allocates a large amount of memory that is never freed after the function returns the gradients. Is there anything I am doing wrong here? Which reference is being kept that I am missing? Any hint is highly appreciated, and my apologies if this is not the right place to ask.

    with higher.innerloop_ctx(model, optimizer) as (fmodel, foptimizer):

        logits = fmodel(input)  # heavy memory allocation here that is never freed
        # calc_instance_weights returns one weight per training sample in `input`
        weights = calc_instance_weights(input, target, input_val, target_val, logits, coefficient_vector, visual_encoder)
        weighted_training_loss = torch.mean(weights * F.cross_entropy(logits, target, reduction='none'))
        foptimizer.step(weighted_training_loss)  # update the main model weights held by fmodel

        logits = fmodel(input)
        meta_val_loss = F.cross_entropy(logits, target)

        # gradients of the validation loss w.r.t. the coefficient vector
        coeff_vector_gradients = torch.autograd.grad(meta_val_loss, coefficient_vector, retain_graph=True)
        coeff_vector_gradients = coeff_vector_gradients[0].detach()
        # gradients of the validation loss w.r.t. the resnet parameters
        visual_encoder_gradients = torch.autograd.grad(meta_val_loss, visual_encoder.parameters())
        # keeps only the first two gradient tensors
        visual_encoder_gradients = (visual_encoder_gradients[0].detach(), visual_encoder_gradients[1].detach())

        return visual_encoder_gradients, coeff_vector_gradients

Thanks a lot!

brando90 commented 2 years ago

Did you find a solution?

Did you try del fmodel?
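
In case it helps, here is a minimal sketch of what trying that could look like. It deliberately uses plain PyTorch with a toy tensor standing in for fmodel and the losses, so it is not the code from the question and not higher's API. The helper names detach_grads, release_cached_memory and cuda_memory_report are hypothetical. The idea is to return only detached copies of the gradients, drop every name that still points at the graph, and then compare allocated versus reserved CUDA memory. PyTorch's caching allocator can keep freed blocks reserved, so memory that looks "never freed" in nvidia-smi may just be cached; whether that is what happens in this issue is an assumption, not something confirmed here.

```python
import gc

import torch


def detach_grads(grads):
    """Return copies of the gradients that hold no reference to the autograd graph."""
    return tuple(g.detach().clone() for g in grads)


def release_cached_memory():
    """Collect unreachable Python objects, then release unused cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def cuda_memory_report():
    """Bytes held by live tensors vs. bytes reserved by the caching allocator."""
    if not torch.cuda.is_available():
        return {"allocated": 0, "reserved": 0}
    return {
        "allocated": torch.cuda.memory_allocated(),
        "reserved": torch.cuda.memory_reserved(),
    }


# Toy stand-in for the meta-gradient computation (hypothetical, not the real models).
w = torch.randn(4, requires_grad=True)
loss = (w ** 2).sum()
# create_graph=True makes the gradient itself part of the graph, as in meta-learning.
(raw_grad,) = torch.autograd.grad(loss, w, create_graph=True)

safe_grads = detach_grads((raw_grad,))  # safe to return from the function

# Analogous to `del fmodel`: drop the names that keep the graph alive.
del loss, raw_grad
release_cached_memory()
report = cuda_memory_report()
```

If "allocated" drops after the cleanup while "reserved" stays high, the memory is only cached by PyTorch and will be reused by later allocations. If "allocated" itself stays high, some tensor that is still reachable (for example one returned without being detached) is probably keeping the graph alive.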