facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0
1.59k stars 123 forks

Inner loop incompatible with weight_norm #14

Open AllanYangZhou opened 5 years ago

AllanYangZhou commented 5 years ago

Hi, thanks for your work on this library!

Using a weight-normalized network in higher's inner loop raises the following error:

load from omniglot.npy.
DB: train (1200, 20, 1, 28, 28) test (423, 20, 1, 28, 28)
Traceback (most recent call last):
  File "maml-omniglot.py", line 271, in <module>
    main()
  File "maml-omniglot.py", line 108, in main
    train(db, net, device, meta_opt, epoch, log)
  File "maml-omniglot.py", line 146, in train
    spt_logits = fnet(x_spt[i])
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 347, in _patched_forward
    return self.boxed_forward(*args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 288, in patched_forward
    return true_forward(self, *args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/higher/patch.py", line 288, in patched_forward
    return true_forward(self, *args, **kwargs)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/iris/u/ayz/venvs/sharing_ws5/lib/python3.7/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #3 'mat2' in call to _th_addmm

I can reproduce this by modifying the maml-omniglot example to apply weight_norm to the final linear layer (script pasted below). The error only appears in higher's inner loop; I can evaluate the network on input data outside the inner loop with no error. I am on Ubuntu 16.04, Python 3.7.0, PyTorch 1.3.0, and CUDA 10.0.

#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
"""

import argparse
import time
import typing

import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import weight_norm

import higher

from support.omniglot_loaders import OmniglotNShot

def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = torch.device('cuda')
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        weight_norm(nn.Linear(64, args.n_way))).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)

def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })

def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })

def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)

# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

if __name__ == '__main__':
    main()
egrefen commented 4 years ago

Thanks for the detailed issue. I am on leave and returning early December. I will try to look into the issue as soon as I can get to it (I'll have a bit of a backlog but will try to see what I can do in the first few weeks of the month).

gcucurull commented 4 years ago

I get the same error if I use weight_norm and run the model on the GPU. If I use the CPU, then I get the following error:

Traceback (most recent call last):
  File "main.py", line 214, in <module>
    higher_train(opt, dataloader, generator, classifier, optimizer_a, optimizer_b)
  File "main.py", line 185, in higher_train
    real_loss.backward()
  File "/home/user/anaconda3/envs/torch1.3/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/anaconda3/envs/torch1.3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
egrefen commented 4 years ago

Hello. I've returned from leave and allocated some time to look into this issue over the next two weeks. Hopefully we'll make some progress and report back, or come back to you with questions.

gcucurull commented 4 years ago

In case it helps:

Looking at PyTorch's weight norm implementation (here), the error could come from the fact that the module's weight attribute is set before every forward pass by a forward pre-hook:

setattr(module, self.name, self.compute_weight(module))

However, because of the reparametrization, weight is no longer a parameter of the module: it is replaced by module.weight_g and module.weight_v, so I'm not sure how higher deals with that.

Maybe the issue is with the function compute_weight that generates w from g and v, and higher isn't patching that.
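To make the reparametrization described above concrete, here is a small sketch in plain PyTorch (no higher involved). It assumes the legacy `torch.nn.utils.weight_norm` API used in this thread with its default `dim=0`. It shows that `weight` is no longer a registered parameter, that `weight_g` and `weight_v` are, and that the cached `weight` is only recomputed as `g * v / ||v||` when the forward pre-hook runs.

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm  # legacy API; newer PyTorch points to nn.utils.parametrizations.weight_norm

torch.manual_seed(0)
lin = weight_norm(nn.Linear(4, 3))

# 'weight' was removed from the parameter list and replaced by weight_g and weight_v.
param_names = sorted(name for name, _ in lin.named_parameters())
print(param_names)  # ['bias', 'weight_g', 'weight_v']

# 'weight' is now a plain tensor attribute, not an nn.Parameter.
print(isinstance(lin.weight, nn.Parameter))  # False

# Change g in place; the cached 'weight' attribute is not updated yet.
with torch.no_grad():
    lin.weight_g.mul_(2.0)
stale = lin.weight.detach().clone()

# Calling the module triggers the pre-hook, which recomputes weight = g * v / ||v||
# (the norm is taken per output row because dim=0).
_ = lin(torch.randn(2, 4))
manual = lin.weight_g * lin.weight_v / lin.weight_v.norm(dim=1, keepdim=True)
print(torch.allclose(lin.weight, manual))  # True
print(torch.allclose(stale, lin.weight))   # False
```

Anything that reads `module.weight` without going through the module's own `__call__` therefore sees whatever tensor the last pre-hook run left behind.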

egrefen commented 4 years ago

This does help. I've read through the weight_norm code from PyTorch, and you are correct that this is something which higher isn't patching. We could write a hacky fix specifically for weight norm, I think, but I would prefer a more general solution that caters to similar use cases. I will need to think through this properly and probably talk to some people from the PyTorch team. I will attempt to look into this in the next two weeks, but it's going to require some effort.

egrefen commented 4 years ago

As you might guess, COVID hit and this did not get looked into. I'll chase this when I have time but unfortunately time is a scarce resource :(

jeffwillette commented 3 years ago

I think this is also an issue for spectral_norm, which uses forward pre-hooks as well (probably the exact same mechanism, but I haven't checked), so a general solution would be awesome, because there are probably other functions that do the same thing.
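One way to check which layers in a model rely on this pattern is to look for submodules with registered forward pre-hooks. The helper below, `modules_with_forward_pre_hooks`, is a hypothetical diagnostic and not part of higher or PyTorch's public API: it reads the private `Module._forward_pre_hooks` dict, which could change between PyTorch versions. On a toy model it shows that both legacy `weight_norm` and `spectral_norm` install such a hook.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm


def modules_with_forward_pre_hooks(model: nn.Module):
    """Hypothetical helper: return (module name, hook class names) for every
    submodule that has at least one forward pre-hook registered.

    Relies on the private attribute `_forward_pre_hooks`.
    """
    found = []
    for name, module in model.named_modules():
        hooks = list(module._forward_pre_hooks.values())
        if hooks:
            found.append((name, [type(h).__name__ for h in hooks]))
    return found


net = nn.Sequential(
    weight_norm(nn.Linear(4, 4)),
    nn.ReLU(),
    spectral_norm(nn.Linear(4, 2)),
)
print(modules_with_forward_pre_hooks(net))
# [('0', ['WeightNorm']), ('2', ['SpectralNorm'])]
```

Any module this reports is one whose effective weight depends on code that runs in `__call__` before `forward`, which is the part gcucurull suspects higher is not patching.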

jeffwillette commented 3 years ago

If anyone is looking for a hack until this is fixed: I found that after calling backward(), I have to pass an input through the model once before entering the higher loop, and then it works.

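_ = model(torch.rand_like(x))  # this goes after backward() and before higher
loss = 0.0
with higher.innerloop_ctx(model, inner_opt) as (fmodel, fopt):
    for x_inner, y_inner in inner_batches:  # placeholder for your inner-loop data
        fopt.step(F.cross_entropy(fmodel(x_inner), y_inner))  # ...do inner loop

    loss += F.cross_entropy(fmodel(x), y)

loss.backward()
opt.step()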
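The CPU-side error reported earlier in this thread ("Trying to backward through the graph a second time"), and why a throwaway forward pass helps, can be illustrated in plain PyTorch without higher. This is a sketch under an assumption: it only shows that the `weight` tensor cached by legacy `weight_norm` carries the autograd graph of the forward pass that produced it, so reusing it directly after `backward()` fails, while one extra forward pass rebuilds it. It does not confirm that this is exactly what happens inside higher's patched modules.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm

torch.manual_seed(0)
lin = weight_norm(nn.Linear(4, 3))
x = torch.randn(2, 4)

# Normal step: the pre-hook recomputes lin.weight, then backward() frees that graph.
lin(x).sum().backward()

# Using the cached lin.weight directly (bypassing the module's __call__ and its
# pre-hook) tries to backpropagate through the already-freed graph.
stale_error = None
try:
    F.linear(x, lin.weight, lin.bias).sum().backward()
except RuntimeError as err:
    stale_error = str(err)
print(stale_error is not None and "second time" in stale_error)  # True

# The hack: a throwaway forward pass re-runs the pre-hook and rebuilds lin.weight.
_ = lin(torch.rand_like(x))
F.linear(x, lin.weight, lin.bias).sum().backward()  # succeeds with the fresh weight
```

This matches the order in the hack above: the dummy call has to come after the previous `backward()` and before the inner loop reads the weight again.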
jeffwillette commented 3 years ago

My hack above eventually broke due to something unrelated, so I took another look at this. I think this is a bug in the norm layers in PyTorch: they seem to create a placeholder weight that stays in place until the first forward pass overwrites it, and that placeholder doesn't get moved to the right device with the rest of the model.

https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/spectral_norm.py#L143

The line above is what I am talking about, but I don't know where or when it eventually gets put on the GPU in the normal flow. Wherever that is, it seems to cause a mismatch with higher.
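The placeholder-weight hypothesis can be checked without a GPU. `Module.to()` only converts registered parameters and buffers, so a tensor stored as a plain attribute, like the `weight` that the legacy norm wrappers set, keeps its old dtype or device until a forward pre-hook overwrites it. The sketch below uses a dtype change as a CPU-only stand-in for a device change; the helper `unregistered_tensor_attrs` is hypothetical and inspects `vars(module)` directly.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm


def unregistered_tensor_attrs(model: nn.Module):
    """Hypothetical helper: map 'module.attr' -> tensor for tensors stored as
    plain attributes (neither parameters nor buffers), which .to() skips."""
    found = {}
    for mod_name, module in model.named_modules():
        for attr, value in vars(module).items():
            if isinstance(value, torch.Tensor):
                found[f"{mod_name}.{attr}" if mod_name else attr] = value
    return found


net = nn.Sequential(weight_norm(nn.Linear(4, 4)), spectral_norm(nn.Linear(4, 2)))
net = net.to(torch.float64)  # stand-in for .to('cuda')

before = {k: v.dtype for k, v in unregistered_tensor_attrs(net).items()}
print(before)  # {'0.weight': torch.float32, '1.weight': torch.float32}
print(net[0].weight_g.dtype)  # torch.float64: registered parameters were converted

net(torch.randn(2, 4, dtype=torch.float64))  # pre-hooks recompute both weights
after = {k: v.dtype for k, v in unregistered_tensor_attrs(net).items()}
print(after)  # {'0.weight': torch.float64, '1.weight': torch.float64}
```

Running a helper like this on a model right after `.to(device)`, before entering `higher.innerloop_ctx`, would show whether any such stale tensors are still on the old device. That is consistent with the dummy forward pass helping, but it does not by itself explain how higher's patched copy picks up the stale tensor.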

gene-chou commented 2 years ago

Wondering if this has been fixed? @jeffwillette I tried your hack but got this error instead:

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected type TensorOptions(dtype=float, device=cpu, layout=Strided, requires_grad=false) but got TensorOptions(dtype=float, device=cuda:0, layout=Strided, requires_grad=false) (validate_outputs at /opt/conda/conda-bld/pytorch_1591914880026/work/torch/csrc/autograd/engine.cpp:484)

Do you know what the issue might be?

hiyamgh commented 1 year ago

Wondering if this has been fixed? Getting the following error:

RuntimeError: Function AddmmBackward returned an invalid gradient at index 0 - expected type TensorOptions(dtype=float, device=cpu, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)) but got TensorOptions(dtype=float, device=cuda:0, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))