learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Error from maml, gradients var not found -- why? #314

Closed brando90 closed 2 years ago

brando90 commented 2 years ago

I got this error:

Traceback (most recent call last):
  File "/home/miranda9/miniconda3/envs/meta_learning_a100/lib/python3.9/site-packages/learn2learn/algorithms/maml.py", line 159, in adapt
    gradients = grad(loss,
  File "/home/miranda9/miniconda3/envs/meta_learning_a100/lib/python3.9/site-packages/torch/autograd/__init__.py", line 226, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
Traceback (most recent call last):
  File "/home/miranda9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_dist_maml_l2l.py", line 627, in <module>
    main()
  File "/home/miranda9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_dist_maml_l2l.py", line 563, in main
    train(args=args)
  File "/home/miranda9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_dist_maml_l2l.py", line 607, in train
    meta_train_iterations_ala_l2l(args, args.agent, args.opt, args.scheduler)
  File "/home/miranda9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/training/meta_training.py", line 149, in meta_train_iterations_ala_l2l
    train_loss, train_loss_std, train_acc, train_acc_std = meta_learner(task_dataset, call_backward=True)
  File "/home/miranda9/miniconda3/envs/meta_learning_a100/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/miranda9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 375, in forward
    meta_loss, meta_loss_ci, meta_acc, meta_acc_ci = forward(meta_learner=self,
  File "/home/miranda9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 318, in forward
    loss, acc = fast_adapt(
  File "/home/miranda9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 272, in fast_adapt
    learner.adapt(adaptation_error)
  File "/home/miranda9/miniconda3/envs/meta_learning_a100/lib/python3.9/site-packages/learn2learn/algorithms/maml.py", line 169, in adapt
    self.module = maml_update(self.module, self.lr, gradients)
UnboundLocalError: local variable 'gradients' referenced before assignment

This is puzzling: I am not doing anything fancy, just copying most of the code from the tutorials, so I don't know why I'd get this error. Any ideas?
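The two tracebacks above fit a common pattern: a variable is assigned inside a `try` block, the exception is printed rather than re-raised, and the variable is then used after the block. The sketch below reproduces that pattern in plain Python. It assumes, based only on the traceback and the printed hint, that `adapt` in `maml.py` behaves this way; `compute_gradients`, `adapt`, and `adapt_fixed` are hypothetical stand-ins, not learn2learn code.

```python
import traceback


def compute_gradients(fail):
    """Hypothetical stand-in for torch.autograd.grad."""
    if fail:
        raise RuntimeError(
            "One of the differentiated Tensors appears to not have been used in the graph."
        )
    return [0.1, 0.2]


def adapt(fail):
    """Mimics the assumed control flow: print the error, then keep going."""
    try:
        gradients = compute_gradients(fail)
    except RuntimeError:
        # The error is printed but swallowed, so `gradients` is never assigned.
        traceback.print_exc()
    # If the try block failed, this line raises UnboundLocalError,
    # which hides the original RuntimeError as the root cause.
    return gradients


def adapt_fixed(fail):
    """One possible alternative: re-raise so the real cause is the final error."""
    try:
        gradients = compute_gradients(fail)
    except RuntimeError as err:
        raise RuntimeError("gradient computation failed; see the chained error") from err
    return gradients
```

Read this way, the `UnboundLocalError` is only a symptom. The error to debug is the first one: at least one tensor passed to `grad` did not take part in computing the loss.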

related: https://github.com/pytorch/pytorch/issues/73697

brando90 commented 2 years ago

Btw, I definitely don't want this:

learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?

It's just a 5CNN on mini-ImageNet/CIFAR-FS, so there should be no unused params.
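One way to check whether a model really has no unused parameters is to call `torch.autograd.grad` with `allow_unused=True` and list the parameters whose gradient comes back as `None`. This is a minimal sketch in plain PyTorch; `TinyNet` and `find_unused_parameters` are hypothetical names for illustration and are not part of learn2learn.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model with one layer that forward() never calls."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # registered, but not part of the graph

    def forward(self, x):
        return self.used(x)


def find_unused_parameters(model, loss):
    """Return the names of trainable parameters that did not contribute to loss."""
    named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    names = [n for n, _ in named]
    params = [p for _, p in named]
    # With allow_unused=True, parameters outside the graph yield None
    # instead of raising the RuntimeError shown in the traceback.
    grads = torch.autograd.grad(loss, params, allow_unused=True)
    return [n for n, g in zip(names, grads) if g is None]


model = TinyNet()
loss = model(torch.randn(3, 4)).sum()
unused_names = find_unused_parameters(model, loss)
print(unused_names)
```

Running the same helper on the real model with a real adaptation loss would show which parameters, if any, trigger the error.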

brando90 commented 2 years ago

related: https://discuss.pytorch.org/t/runtimeerror-one-of-the-differentiated-tensors-appears-to-not-have-been-used-in-the-graph-set-allow-unused-true-if-this-is-the-desired-behavior/43679/9

https://stackoverflow.com/questions/71271069/why-cant-pytorch-find-all-the-params-to-do-gradient-descent-with-maml-and-the-l

seba-1511 commented 2 years ago

Could you share a small example on colab? It's hard to help without knowing what you're doing.

brando90 commented 2 years ago

@seba-1511 how do I set the options from the learn2learn hint, "Maybe try with allow_nograd=True and/or allow_unused=True"?

brando90 commented 2 years ago

For learn2learn, this solved my issue.

Solution: pass allow_unused=True when constructing the MAML wrapper, as in the l2l.algorithms.MAML(...) call in the code below.

Code:

#!/usr/bin/env python3

"""
Demonstrates how to:
    * use the MAML wrapper for fast-adaptation,
    * use the benchmark interface to load mini-ImageNet, and
    * sample tasks and split them in adaptation and evaluation sets.
To contrast the use of the benchmark interface with directly instantiating mini-ImageNet datasets and tasks, compare with `protonet_miniimagenet.py`.
"""

import random
import numpy as np

import torch
from torch import nn, optim

import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)

def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy

def main(
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_samples=2 * shots,
                                                  train_ways=ways,
                                                  test_samples=2 * shots,
                                                  test_ways=ways,
                                                  root='~/data/l2l_data/',
                                                  )

    # Create model
    # model = l2l.vision.models.MiniImagenetCNN(ways)
    from uutils.torch_uu.models.hf_uu.vit_uu import get_vit_get_vit_model_and_model_hps_mi
    model, _ = get_vit_get_vit_model_and_model_hps_mi()
    model.to(device)
    # allow_unused=True lets adapt() tolerate parameters that receive no gradient
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False, allow_unused=True)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in maml.parameters():
            if p.grad is not None:
                p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-testing loss
        learner = maml.clone()
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

if __name__ == '__main__':
    """
python ~/ultimate-utils/tutorials_for_myself/my_l2l/serial_maml_l2l_hf_vit_simple.py
    """
    main()