learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

MAML l2l not giving gradients? why? UnboundLocalError: local variable 'gradients' referenced before assignment #387

Closed brando90 closed 1 year ago

brando90 commented 1 year ago

I have four tensors, spt_x, spt_y, qry_x, and qry_y, that I concatenate and wrap in a dummy TaskDataset object so that L2L works, but I get the following error:

Error:

Traceback (most recent call last):
  File "/lfs/ampere4/0/brando9/miniconda/envs/mds_env_gpu/lib/python3.9/site-packages/learn2learn/algorithms/maml.py", line 159, in adapt
    gradients = grad(loss,
  File "/lfs/ampere4/0/brando9/miniconda/envs/mds_env_gpu/lib/python3.9/site-packages/torch/autograd/__init__.py", line 226, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?
Traceback (most recent call last):
  File "/lfs/ampere4/0/brando9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_maml_torchmeta.py", line 427, in <module>
    # - run experiment
  File "/lfs/ampere4/0/brando9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_maml_torchmeta.py", line 359, in main
    train(rank=-1, args=args)
  File "/lfs/ampere4/0/brando9/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/main_maml_torchmeta.py", line 403, in train
    elif 'iterations' in args.training_mode:
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/training/meta_training.py", line 110, in meta_train_fixed_iterations
    log_zeroth_step(args, meta_learner)
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/logging_uu/wandb_logging/supervised_learning.py", line 170, in log_zeroth_step
    train_loss, train_acc = model(batch, training=training)
  File "/lfs/ampere4/0/brando9/miniconda/envs/mds_env_gpu/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 276, in forward
    meta_batch_size: int = max(self.args.batch_size // self.args.world_size, 1)
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 174, in forward
    # - adapt
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 213, in get_lists_accs_losses_l2l
    learner = meta_learner.maml.clone()
  File "/afs/cs.stanford.edu/u/brando9/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/meta_learners/maml_meta_learner.py", line 128, in fast_adapt
    (support_data, support_labels), (query_data, query_labels) = learn2learn.data.partition_task(
  File "/lfs/ampere4/0/brando9/miniconda/envs/mds_env_gpu/lib/python3.9/site-packages/learn2learn/algorithms/maml.py", line 169, in adapt
    self.module = maml_update(self.module, self.lr, gradients)
UnboundLocalError: local variable 'gradients' referenced before assignment

Why does this happen?

I wouldn't expect my data to affect how the MAML meta-learner behaves. This is my dataset wrapper:

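from typing import Optional

import torch
from torch import Tensor
from learn2learn.data import TaskDataset


class EpisodicBatchAsTaskDataset(TaskDataset):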

    def __init__(self, batch: Tensor):
        self.batch = batch
        print(f'{len(self.batch)=}')  # e.g. 4 [spt_x, spt_y, qry_x, qry_y]
        assert len(batch) == 4, f'Error: Expected 4 tensors in batch because we have 4 [spt_x, spt_y, qry_x, qry_y] ' \
                                f'but got {len(batch)}.'
        print(f'{batch[0].size()=}')  # e.g. for vision it should be [B, n*k, C, H, W]
        print(f'{batch[1].size()=}')  # e.g. for vision it should be [B, n*k]
        self.idx = 0

    def sample(self, idx: Optional[int] = None,
               ) -> list[Tensor]:
        """
        Gets a single task from the batch of tasks.

        The l2l forward pass is as follows:
            meta_losses, meta_accs = [], []
            for task in range(meta_batch_size):
                # print(f'{task=}')
                # - Sample all data for spt & qry sets for current task: thus size [n*(k+k_eval), C, H, W] (or [n*(k+k_eval), D])
                task_data: list = task_dataset.sample()  # data, labels

                # -- Inner Loop Adaptation
                learner = meta_learner.maml.clone()
                loss, acc = fast_adapt(
                    args=args,
                    task_data=task_data,
                    learner=learner,
                    loss=args.loss,
                    adaptation_steps=meta_learner.nb_inner_train_steps,
                    shots=args.k_shots,
                    ways=args.n_classes,
                    device=args.device,
                )
        Therefore, we need to concatenate each task's tensors along the right dimension and return them. The forward pass
        then splits them according to the shots and ways on its own.
        """
        if idx is None:
            idx = self.idx
        else:
            self.idx = idx

        # - want x, y to be of shape: [n*(k+k_eval), C, H, W] (or [n(k+k_eval), D])
        spt_x, spt_y, qry_x, qry_y = self.batch
        spt_x, spt_y, qry_x, qry_y = spt_x[idx], spt_y[idx], qry_x[idx], qry_y[idx]
        # - concatenate spt_x, qry_x & spt_y, qry_y
        x = torch.cat([spt_x, qry_x], dim=0)
        y = torch.cat([spt_y, qry_y], dim=0)
        print(f'{x.size()=}')
        print(f'{y.size()=}')
        task_data: list = [x, y]
        self.idx += 1
        return task_data

brando90 commented 1 year ago

@seba-1511 I don't know why, but this hint didn't display in my terminal, although it DID display in PyCharm:

learn2learn: Maybe try with allow_nograd=True and/or allow_unused=True ?
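
Judging from the two tracebacks above, the `UnboundLocalError` looks like a follow-on effect of the first error rather than a separate bug: the `grad(...)` call raises, the hint is printed, and execution then continues to a line that reads `gradients` even though the assignment never completed. Below is a minimal pure-Python sketch of that pattern. The names `failing_grad` and `adapt` are hypothetical stand-ins, not the library's code.

```python
import traceback


def failing_grad():
    """Hypothetical stand-in for a gradient call that raises on an unused tensor."""
    raise RuntimeError(
        "One of the differentiated Tensors appears to not have been used in the graph."
    )


def adapt():
    try:
        gradients = failing_grad()
    except RuntimeError:
        # The error is reported and a hint is printed, but execution continues.
        traceback.print_exc()
        print("hint: maybe try allow_unused=True ?")
    # The assignment above never completed, so `gradients` is still unbound here.
    return gradients


try:
    adapt()
    outcome = "no error"
except UnboundLocalError as err:
    outcome = type(err).__name__

print(outcome)
```

So the error to fix is the first one (`RuntimeError`); the second goes away once the gradient call succeeds.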

seba-1511 commented 1 year ago

Correct, I believe these flags should fix your issue. Closing.
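
To illustrate what `allow_unused=True` changes, here is a small sketch that uses only `torch.autograd.grad`, the call shown in the first traceback. The `TwoHeads` model is a hypothetical toy in which one layer never takes part in the forward pass; it is an assumption for illustration and not the model from this issue.

```python
import torch
from torch import nn
from torch.autograd import grad


class TwoHeads(nn.Module):
    """Hypothetical toy model: only `used` takes part in the forward pass."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward()

    def forward(self, x):
        return self.used(x)


torch.manual_seed(0)
model = TwoHeads()
loss = model(torch.randn(3, 4)).sum()
params = list(model.parameters())

# Without allow_unused, autograd refuses to differentiate w.r.t. unused tensors.
try:
    grad(loss, params, retain_graph=True)
    raised = False
except RuntimeError:
    raised = True

# With allow_unused=True, the unused parameters get a gradient of None instead.
gradients = grad(loss, params, allow_unused=True)
unused_names = [
    name
    for (name, _), g in zip(model.named_parameters(), gradients)
    if g is None
]
print(raised, unused_names)
```

Listing the parameters whose gradient is `None` in this way can also help find which part of a model is not reached by the loss, which may be worth checking before enabling the flags suggested in the hint.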