learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Gradient accumulation in the inner loop #394

Closed jkang1640 closed 1 year ago

jkang1640 commented 1 year ago

Hello,

I am using learn2learn for seq2seq models. Thank you very much for the awesome library!

Because of the GPU memory limit, I'd like to use gradient accumulation. For the outer loop there is no problem implementing this, but inside the inner loop I don't know how to do it efficiently, since the adapt step includes both the backward pass and the update step.

Here's my code below. This is not an efficient way at all, since I keep accumulating outputs.loss instead of accumulating only the gradients. Therefore, the computation graphs are not released until I call learner.adapt(loss).

This causes OOM when I run my code. Is there any way to apply gradient accumulation efficiently for the inner loop? Thank you very much for your help!

    def inner_loop(learner, support_dl, adaptation_steps):

        total_loss = []
        for step in range(adaptation_steps):
            loss = 0.0

            for inputs in support_dl:
                outputs = learner(**inputs)
                curr_loss = outputs.loss
                # weight each batch loss by its share of the support set
                loss += curr_loss * support_dl.batch_size / len(support_dl.dataset)

            learner.adapt(loss)  # backward pass + parameter update in one call
            total_loss.append(loss.item())

        return learner, sum(total_loss) / len(total_loss)
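
For reference, the standard single-model accumulation pattern frees each batch's graph right after backward(), which is exactly what adapt's fused backward-and-update prevents here (a minimal plain-PyTorch sketch; model, optimizer, and dataloader are placeholders):

    def accumulate_plain(model, optimizer, dataloader):
        # standard accumulation: one backward per batch, one update at the end;
        # each batch's graph is freed as soon as backward() returns
        optimizer.zero_grad()
        for inputs in dataloader:
            outputs = model(**inputs)
            # weight each batch loss by its share of the dataset
            loss = outputs.loss * dataloader.batch_size / len(dataloader.dataset)
            loss.backward()  # frees this batch's graph immediately
        optimizer.step()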
jkang1640 commented 1 year ago

After reading maml.py, I changed my code as follows.

Could anybody tell me if it seems right? Just one more thing: I have no idea when to use learner.module versus learner. For example, which one should I return at the end of the code? Which one should I pass to maml_update?

    import torch
    from torch import autograd
    from learn2learn.algorithms import maml_update

    def inner_loop(learner, support_dl, fast_lr, adaptation_steps):

        all_loss = []
        # number of batches over which gradients are accumulated
        accum_steps = len(support_dl.dataset) / support_dl.batch_size

        for step in range(adaptation_steps):
            loss = 0.0
            # reset the accumulated gradients to zero at the start of each adaptation step
            accum_grads = tuple(torch.zeros_like(param) for param in learner.module.parameters())

            for inputs in support_dl:
                outputs = learner(**inputs)
                curr_loss = outputs.loss / accum_steps
                curr_grads = autograd.grad(curr_loss, learner.module.parameters(), create_graph=False)  # False for FOMAML
                accum_grads = tuple(accum_grads[i] + curr_grads[i] for i in range(len(curr_grads)))
                loss += curr_loss.item()  # accumulate the loss for logging
                del outputs, curr_loss, curr_grads

            # update the learner's parameters with the accumulated gradients
            maml_update(learner, lr=fast_lr, grads=accum_grads)
            all_loss.append(loss)

        return learner
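
And for context, this is roughly how I call it from the outer loop (a minimal sketch; task_batch, compute_query_loss, meta_lr, and the task attributes are placeholders for my actual setup):

    import torch
    import learn2learn as l2l

    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=True)  # first-order, matching create_graph=False above
    opt = torch.optim.Adam(maml.parameters(), lr=meta_lr)

    opt.zero_grad()
    for task in task_batch:  # placeholder iterable of tasks
        learner = maml.clone()  # clone so inner-loop updates don't touch maml's weights
        learner = inner_loop(learner, task.support_dl, fast_lr, adaptation_steps)
        query_loss = compute_query_loss(learner, task.query_dl)  # placeholder helper
        (query_loss / len(task_batch)).backward()  # outer-loop gradient accumulation
    opt.step()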
seba-1511 commented 1 year ago

Hello @jkang1640,

Sequence models are notoriously tricky to train with MAML, since it's easy to run into OOMs like you did.

Unfortunately, there's no easy fix when you're trying to accumulate gradients in the inner loop. My advice is to either pretrain and freeze some of the parameters of the model, as in ANIL (to reduce memory consumption during meta-learning), or go for a first-order algorithm (FOMAML, ANIL).
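
For instance, switching to first-order is just a flag on the MAML wrapper (a minimal sketch; model, fast_lr, and loss stand in for your setup):

    import learn2learn as l2l

    # first_order=True drops the second-order terms, so adapt() no longer
    # has to retain the inner-loop graphs for the meta-gradient
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=True)
    learner = maml.clone()
    learner.adapt(loss)  # loss is a placeholder computed on the support batch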

I'll close this for now.