learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Is LearnableOptimizer compatible with MAML? #381

Closed jkang1640 closed 1 year ago

jkang1640 commented 1 year ago

I'm trying to use LearnableOptimizer with MAML, and on the second iteration, at the end of the second task batch, I always get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Here's my code. Could anybody help me find out why I get this error? Thank you very much!


    import torch
    from torch import nn, optim
    import learn2learn as l2l
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Defined elsewhere in the original script (not shown here): HypergradTransform,
    # fast_adapt, partition_task_text, train_gen_list, device, ways, shots,
    # fast_lr, meta_lr, num_iterations, meta_batch_size, adaptation_steps,
    # inner_loop_batch_size.

    # Create model
    model_name = "distilbert-base-multilingual-cased"  # DistilBERT only for debugging purposes
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=ways, ignore_mismatched_sizes=True)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr)  # or MetaSGD
    metaopt = l2l.optim.LearnableOptimizer(
        model=maml,  # We pass the model, not its parameters
        transform=HypergradTransform,  # Any transform could work
        lr=3e-5)
    metaopt.to(device)
    opt = optim.SGD(metaopt.parameters(), meta_lr)  # updates the meta-optimizer's own parameters
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(1, num_iterations+1):
        opt.zero_grad()
        metaopt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0

        for task in range(meta_batch_size):
            learner = maml.clone()
            x, y = next(train_gen_list).sample()
            support, query = partition_task_text(x, y, shots=shots)
            evaluation_error, evaluation_accuracy = fast_adapt(support,
                                                               query,
                                                               learner,
                                                               tokenizer,
                                                               loss,
                                                               adaptation_steps,
                                                               device,
                                                               batch_size=inner_loop_batch_size)
            # Meta-gradient for this task; accumulates into .grad across tasks
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()  # update the meta-optimizer
        metaopt.step()  # update the MAML model's parameters
seba-1511 commented 1 year ago

Hello @jkang1640,

Could you create a short Colab with a small linear regression task that reproduces the bug?

In principle, learned optimizers should work with MAML, but nested gradients-of-gradients always come with subtleties, so a small failing example helps with debugging.
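As an illustration of the gradients-of-gradients mentioned above, here is a minimal plain-PyTorch sketch (not learn2learn code) of one MAML-style step on a hypothetical scalar weight w with a learnable inner-loop learning rate alpha. The quadratic losses and the targets 3.0 and 2.0 are arbitrary choices for the example. The inner gradient is taken with create_graph=True so that the outer backward pass can differentiate through it.

```python
import torch

# Hypothetical toy setup: one scalar weight and a learnable inner learning rate.
w = torch.tensor(1.0, requires_grad=True)
alpha = torch.tensor(0.1, requires_grad=True)

def loss_fn(weight, target):
    return (weight - target) ** 2

# Inner (adaptation) step: create_graph=True records how the gradient itself
# depends on w, so the adapted weight stays differentiable w.r.t. w and alpha.
support_loss = loss_fn(w, 3.0)
(g,) = torch.autograd.grad(support_loss, w, create_graph=True)
w_adapted = w - alpha * g

# Outer (meta) step: backpropagating the query loss goes through g,
# i.e. it differentiates a gradient.
query_loss = loss_fn(w_adapted, 2.0)
query_loss.backward()

print(w.grad.item(), alpha.grad.item())
```

With create_graph=False, g would be treated as a constant, which gives the first-order approximation: w.grad would be -1.2 here instead of the second-order -0.96.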

jkang1640 commented 1 year ago

Hello @seba-1511

Here's the link to the Colab: https://colab.research.google.com/drive/1aHeIXX1whCZLaNe2XmyzElJfg3zd0aNG?usp=sharing

Basically, I'd like to do multilingual text classification with MAML (few-shot transfer to an unseen language).

I changed my code a little so that MAML only trains a linear model on top of embeddings from a frozen LM (instead of fine-tuning the LM). I hope this matches the small linear regression task you asked for.

As you can see, it reproduces the error I mentioned earlier in this thread. Could you take a look and help me figure out where it comes from?

Thank you very much for your help

seba-1511 commented 1 year ago

Thanks for the Colab @jkang1640.

I quickly played with it, and it seems like changing l.66 in the last cell to evaluation_error.backward(create_graph=True) fixes the problem. Could you check if that solves it on your end too?

Edit: scratch that, I'm getting a memory leak now.

jkang1640 commented 1 year ago

Thank you very much @seba-1511. Yes, I also get a memory leak, but at least it solved the first problem, which was not obvious at all. You've just saved my day! Thank you so much.

I'll have to investigate further why I got this error, but do you have any intuition about why it was thrown in the first place, and why on the second iteration rather than the first?

One more question: do you think it should be create_graph=True instead of retain_graph=True? I tried retain_graph=True and it seems to fix the memory leak, but I'm still scratching my head over which solution is right.

Thank you

seba-1511 commented 1 year ago

I tried something else: keep create_graph=False on l.66 and comment out learner.eval() if eval else learner.train() in fast_adapt(). This works well, but you need to decrease the batch sizes because it runs out of memory (though there is no memory leak). See my copy of your Colab.

Also, I believe you need to divide the gradients of the meta-optimizer in l.72 of the last cell by meta_batch_size, since they are accumulated across all tasks in a meta-batch.
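The averaging step can be sketched in plain PyTorch as follows. This is a minimal illustration, not the thread's code: a hypothetical torch.nn.Linear stands in for the meta-optimizer's learnable parameters, and random inputs stand in for the tasks. It shows that summing per-task gradients with one backward call per task and then scaling by 1 / meta_batch_size gives the gradient of the mean task loss.

```python
import torch

torch.manual_seed(0)
meta_batch_size = 3

# Hypothetical stand-in for the meta-optimizer's learnable parameters
# (metaopt.parameters() in the thread).
meta_params = torch.nn.Linear(4, 1)
tasks = [torch.randn(5, 4) for _ in range(meta_batch_size)]

def task_loss(x):
    return meta_params(x).pow(2).mean()

# One backward call per task: gradients accumulate (sum) in .grad.
meta_params.zero_grad()
for x in tasks:
    task_loss(x).backward()

# Divide by the number of tasks, skipping parameters without a gradient.
for p in meta_params.parameters():
    if p.grad is not None:
        p.grad.mul_(1.0 / meta_batch_size)

# Check against the gradient of the mean loss, computed in a single pass.
mean_loss = sum(task_loss(x) for x in tasks) / meta_batch_size
reference = torch.autograd.grad(mean_loss, list(meta_params.parameters()))
matches = [torch.allclose(p.grad, ref, atol=1e-6)
           for p, ref in zip(meta_params.parameters(), reference)]
print(matches)
```

In the training loop from the original post, the same loop would presumably run over metaopt.parameters() just before opt.step().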

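Regarding retain_graph: retain_graph=True keeps the saved intermediate values so that you can backpropagate through the same computation graph again, whereas create_graph=True additionally records the backward pass itself so that you can differentiate the resulting gradients (higher-order derivatives). In this case we need neither, hence the solution above.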
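The difference between the two flags can be seen on scalar toy tensors in plain PyTorch. This is a minimal sketch; the values 2.0 and 3.0 and the functions exp and w ** 3 are arbitrary choices that have nothing to do with the thread's model.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x.exp()  # exp saves its output for the backward pass

# Without retain_graph, the first backward frees the saved tensors,
# so a second backward through the same graph raises a RuntimeError.
y.backward()
try:
    y.backward()
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# With retain_graph=True the saved tensors are kept, so the same graph can be
# backpropagated again; the two gradients accumulate in x.grad (2 * exp(2)).
x.grad = None
z = x.exp()
z.backward(retain_graph=True)
z.backward()
print(second_backward_failed, x.grad.item())

# create_graph=True records the backward pass itself, so the gradient is a
# differentiable tensor and can be differentiated again (second derivative).
w = torch.tensor(3.0, requires_grad=True)
(dw,) = torch.autograd.grad(w ** 3, w, create_graph=True)  # 3 * w**2
(d2w,) = torch.autograd.grad(dw, w)                        # 6 * w
print(dw.item(), d2w.item())
```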

jkang1640 commented 1 year ago

I really appreciate your quick answer and suggested solutions.

I'm afraid your second solution doesn't really solve the problem, though: it seems it only worked because meta_batch_size=1. When I increase it to 3, it throws the same error :(

seba-1511 commented 1 year ago

Right, this is unexpected. I wonder if it's due to extracting features with the LM first, but I don't have time to look into this now.

My recommendation is to keep retain_graph=True, since that doesn't run out of memory (and the graph is freed once the variables referencing it go out of scope).
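One possible explanation for why the error only appears on the second iteration, and only when meta_batch_size > 1, is sketched below in plain PyTorch. It rests on an assumption that this thread does not confirm: that after metaopt.step() the model's parameters are no longer leaf tensors but differentiable functions of the transform's learnable step size, so every task in the next meta-batch backpropagates through the same shared update graph. learned_update and meta_batch are hypothetical helpers; the step size 0.1 and the stand-in gradient 0.5 are arbitrary.

```python
import torch

def learned_update(weight, step_size, grad):
    """Hypothetical learned-optimizer step: the new weight is cut off from its
    old history but stays differentiable w.r.t. the learnable step size."""
    return weight.detach() - step_size * grad

def meta_batch(meta_batch_size, retain_graph):
    """Backpropagate one loss per task through a shared, already-updated weight.

    Returns the accumulated step-size gradient, or "RuntimeError" if autograd
    refuses to reuse the freed graph."""
    step_size = torch.nn.Parameter(torch.tensor(0.1))  # learnable, like a transform
    weight = torch.tensor(1.0)
    grad = torch.tensor(0.5)  # stand-in for the gradient seen by the optimizer
    new_weight = learned_update(weight, step_size, grad)  # shared by all tasks
    try:
        for task in range(meta_batch_size):
            task_loss = ((task + 1) * new_weight) ** 2  # a different loss per task
            task_loss.backward(retain_graph=retain_graph)
    except RuntimeError:
        return "RuntimeError"
    return step_size.grad.item()

print(meta_batch(1, retain_graph=False))  # one task: works (about -0.95)
print(meta_batch(3, retain_graph=False))  # "RuntimeError" on the second task
print(meta_batch(3, retain_graph=True))   # works (about -13.3)
```

Under this assumption, the first task's backward frees the saved tensors of the shared update, so the second task fails; a meta-batch with a single task never reuses that graph, and retain_graph=True keeps the saved tensors alive only until new_weight goes out of scope at the end of meta_batch.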

seba-1511 commented 1 year ago

Closing since inactive.