Closed jkang1640 closed 1 year ago
After reading `maml.py`, I changed my code as follows. Could anybody tell me whether it looks right? One more thing: I have no idea when to use `learner.module` versus `learner`. For example, which one should I return at the end of the code, and which one should I pass to `maml_update`?
```python
import torch
from torch import autograd
from learn2learn.algorithms import maml_update


def inner_loop(learner, support_dl, fast_lr, adaptation_steps):
    # `learner` is a learn2learn MAML clone; its wrapped model is `learner.module`,
    # which is what MAML.adapt() in maml.py differentiates and updates.
    all_loss = []  # per-step support loss, for logging
    num_batches = len(support_dl)  # exact batch count, unlike len(dataset) / batch_size
    for _ in range(adaptation_steps):
        loss = 0.0
        # Re-read the parameters every step (maml_update replaces them) and reset
        # the accumulator, so gradients from earlier steps do not leak into this one.
        params = list(learner.module.parameters())
        accum_grads = [torch.zeros_like(p) for p in params]
        for inputs in support_dl:
            outputs = learner(**inputs)
            curr_loss = outputs.loss / num_batches
            # create_graph=False (first-order): this batch's graph is freed right away.
            curr_grads = autograd.grad(curr_loss, params, create_graph=False)
            accum_grads = [a + g for a, g in zip(accum_grads, curr_grads)]
            loss += curr_loss.item()  # accumulate loss for logging
            del outputs, curr_loss, curr_grads
        # Same pattern as MAML.adapt(): update the wrapped module and reassign it.
        learner.module = maml_update(learner.module, lr=fast_lr, grads=accum_grads)
        all_loss.append(loss)
    # Return the MAML wrapper so the caller can keep calling learner(...) on the query set.
    return learner
```
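To check the accumulation logic itself without PyTorch, here is a minimal pure-Python sketch on a hypothetical 1-D linear model `y_hat = w * x` with squared-error loss, where gradients can be written by hand. It shows that summing per-batch gradients gives the same value as the gradient over the whole support set, so only the running sum (not every batch's graph) has to be kept. One caveat it also illustrates: dividing each batch loss by the number of batches matches the full-set mean only when all batches have the same size, so this sketch weights each batch by its share of the examples instead. All names (`batch_grad`, `accumulated_grad`, `inner_loop`) are illustrative, not learn2learn API.

```python
# Hypothetical toy model: y_hat = w * x with mean squared error.


def batch_loss(w, batch):
    """Mean squared error over one batch of (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def batch_grad(w, batch):
    """d(batch_loss)/dw for one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batches):
    """Sum per-batch gradients weighted by batch size, which equals the
    full-support-set gradient even when the last batch is smaller.
    Only one float survives each iteration (the analogue of freeing
    each batch's graph in PyTorch)."""
    total = sum(len(b) for b in batches)
    acc = 0.0
    for batch in batches:
        acc += batch_grad(w, batch) * len(batch) / total
    return acc


def inner_loop(w, batches, fast_lr, adaptation_steps):
    """SGD adaptation: one update per step, with a fresh accumulator each step."""
    total = sum(len(b) for b in batches)
    losses = []
    for _ in range(adaptation_steps):
        # Support loss at the current weight, kept only for logging.
        losses.append(sum(batch_loss(w, b) * len(b) / total for b in batches))
        w = w - fast_lr * accumulated_grad(w, batches)
    return w, losses


support = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]  # y = 2x
batches = [support[0:2], support[2:4], support[4:]]  # last batch has one example

full = batch_grad(0.0, support)
acc = accumulated_grad(0.0, batches)
print(full, acc)  # -44.0 -44.0
# Dividing by the number of batches instead over-weights the small last batch:
naive = sum(batch_grad(0.0, b) for b in batches) / len(batches)
print(naive)  # about -53.33
w, losses = inner_loop(0.0, batches, fast_lr=0.02, adaptation_steps=3)
print(w, losses)
```

In a PyTorch DataLoader with `drop_last=False`, the last batch can be smaller in the same way, so the batch-size weighting is the safer choice when exact equivalence matters.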
Hello @jkang1640,
Sequence models are notoriously tricky to train with MAML, because it's easy to run out of GPU memory (OOM), as you did.
Unfortunately, there's no easy fix when you're trying to accumulate gradients in the inner loop. My advice is to either pretrain the model and freeze some of its parameters as in ANIL (to reduce memory consumption during meta-learning), or to use a first-order algorithm (FOMAML, ANIL).
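The first-order route can be sketched in the same toy setting. Below is a minimal, hypothetical pure-Python FOMAML meta-step on 1-D linear-regression tasks (`y = a * x`): each task is adapted on its support set with accumulated gradients, and the query-set gradient at the adapted weight is used directly as the meta-gradient, dropping the second-order term. This illustrates the general first-order MAML idea under those assumptions, not learn2learn's implementation; all names and tasks are made up for illustration.

```python
# Hypothetical toy tasks: task k is y = a_k * x; the model is y_hat = w * x (MSE).


def batch_grad(w, batch):
    """d/dw of mean((w * x - y)^2) over one batch of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batches):
    """Batch-size-weighted sum of per-batch gradients (= full-set gradient)."""
    total = sum(len(b) for b in batches)
    return sum(batch_grad(w, b) * len(b) / total for b in batches)


def adapt(w, support_batches, fast_lr, steps):
    """Inner loop: SGD on the support set using accumulated gradients."""
    for _ in range(steps):
        w = w - fast_lr * accumulated_grad(w, support_batches)
    return w


def fomaml_step(meta_w, tasks, fast_lr, meta_lr, steps):
    """One first-order MAML meta-update.

    The query gradient evaluated at the adapted weight is used directly as
    the meta-gradient (the inner loop's Jacobian is ignored), so no
    second-order information has to be stored.
    """
    meta_grad = 0.0
    for support_batches, query_batches in tasks:
        adapted = adapt(meta_w, support_batches, fast_lr, steps)
        meta_grad += accumulated_grad(adapted, query_batches)
    return meta_w - meta_lr * meta_grad / len(tasks)


def make_task(a):
    support = [[(1.0, a * 1.0)], [(2.0, a * 2.0)]]  # two support batches
    query = [[(3.0, a * 3.0)]]
    return support, query


tasks = [make_task(1.0), make_task(3.0)]
meta_w = 0.0
for _ in range(50):
    meta_w = fomaml_step(meta_w, tasks, fast_lr=0.1, meta_lr=0.05, steps=1)
print(meta_w)  # approaches 2.0 for these two symmetric tasks
```

Because no second derivatives are needed, in PyTorch each inner-loop gradient can be taken with `create_graph=False`, which is what keeps memory from growing with the number of support batches.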
I'll close this for now.
Hello,
I am using learn2learn for seq2seq models. Thank you very much for the awesome library!
Because of GPU memory limits, I'd like to use gradient accumulation. For the outer loop this is no problem to implement, but inside the inner loop I don't know how to do it efficiently, since the `adapt` step performs both the backward pass and the parameter update. Here's my code below. It is not efficient at all, because I keep accumulating `outputs.loss` instead of accumulating only the gradients. As a result, the computation graphs are not released until I call `learner.adapt(loss)`, which causes OOM when I run my code. Is there any way to apply gradient accumulation efficiently in the inner loop? Thank you very much for your help!