facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Question about step execution time #20

Open AntoineHX opened 4 years ago

AntoineHX commented 4 years ago

Hi,

I was wondering: is it intended that the diffopt.step(loss) call takes increasingly long to execute as the number of saved states grows? The step time should be constant, since we back-propagate through only one state, whereas the computation of the meta-gradient at the end of the inner loop should take longer as the number of saved states increases, right?

t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
print(len(fmodel._fast_params),"step", time.process_time()-t)

(...)

t = time.process_time()
val_loss.backward()
print("meta", time.process_time()-t)

This code gives me :

2 step 0.513532734
3 step 1.2003996619999988
4 step 1.4545214800000004
5 step 1.6974909480000004
6 step 1.9400910080000013
7 step 2.1659202289999975
meta 1.4035290689999975
2 step 0.4082054819999996
3 step 1.236462700999997
4 step 1.4650358509999997
5 step 1.702944763999998
6 step 1.9239114150000027
meta 1.1481397730000005

EDIT :

After looking into the DifferentiableOptimizer code, it seems that what's causing this slowdown is the construction of the derivative graph in:

all_grads = _torch.autograd.grad(
    loss,
    grad_targets,
    create_graph=self._track_higher_grads,
    allow_unused=True  # boo
)

I'm not really familiar with how autograd handles this, but it seems the whole graph is recomputed at each call. Wouldn't it be possible to keep the previous graph and extend it as the gradient tape expands with the new states?
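
For context, create_graph=True (set here from self._track_higher_grads) makes the gradient computation itself differentiable, which is what lets a later val_loss.backward() reach back through every inner step. A minimal toy of that mechanism, independent of higher:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# create_graph=True keeps the gradient computation in the graph,
# so the gradient can itself be differentiated afterwards.
(g,) = torch.autograd.grad(y, x, create_graph=True)  # g = 3 * x**2 = 12
(g2,) = torch.autograd.grad(g, x)                    # g2 = 6 * x = 12
print(g.item(), g2.item())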

egrefen commented 4 years ago

@AntoineHX This should not be the case. Would you please provide a minimal repro (small self-contained bit of code) which reproduces the behaviour above? We will investigate ASAP.

AntoineHX commented 4 years ago

@egrefen That's good to know. The issue could be on my side then. It seems to happen when I use a functional model built from a model wrapped inside a class with a parameter (on which I wish to get a gradient).

Here's the repro:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import higher
import time

data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)

class Aug_model(nn.Module):
    def __init__(self, model, hyper_param=True):
        super(Aug_model, self).__init__()

        #### Origin of the issue ? ####
        if hyper_param:
            self._params = nn.ParameterDict({
                    "hyper_param": nn.Parameter(torch.Tensor([0.5])),
                })
        ###############################

        self._mods = nn.ModuleDict({
            'model': model,
            })

    def forward(self, x):
        return self._mods['model'](x) #* self._params['hyper_param']

    def __getitem__(self, key):
        return self._mods[key]

if __name__ == "__main__":

    device = torch.device('cuda:1')
    aug_model = Aug_model(
                    model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
                    hyper_param=True #False will not extend step time
                    ).to(device)

    inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)

    fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(), fmodel=fmodel, track_higher_grads=True)

    for i, (xs, ys) in enumerate(dl_train):
        xs, ys = xs.to(device), ys.to(device)

        logits = fmodel(xs)
        loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')

        t = time.process_time()
        diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
        print(len(fmodel._fast_params),"step", time.process_time()-t)

egrefen commented 4 years ago

It's entirely possible the problem is not your fault. Either way, this is not normal and we'll get to the bottom of this. To manage expectations, I plan to take a look at this on Monday next week, but please do let me know if this is somehow more time-sensitive.

AntoineHX commented 4 years ago

Do you have any leads? Let me know if I can help you :smiley:

egrefen commented 4 years ago

I can repro your issue on CPU using the script you provided. Thanks for narrowing it down to the step where the higher order graph is created. Running your code with track_higher_grads=False in the diffopt creation yields a constant step time.
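
For reference, the only change relative to the repro is the flag in the diffopt construction (sketch; same call as above):

diffopt = higher.optim.get_diff_optim(
    inner_opt, aug_model.parameters(), fmodel=fmodel,
    track_higher_grads=False,  # updates are no longer differentiable: step time stays constant, but meta-gradients become unavailable
)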

I'm not entirely sure why this is happening: the backward graph is bigger, I suppose, but it should already be partly created. I'm not sure what is causing a bigger graph to be created from scratch. I'll have a think about this later this week, but I think we'd need to narrow it down a bit and/or figure out the simplest repro that doesn't have deep higher dependencies, so I can go to the PyTorch team for help if needed.
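
Such a higher-free repro might look roughly like this sketch: SGD unrolled by hand on a toy linear model, with every update kept differentiable via create_graph=True, timing each grad call (all names here are illustrative, not higher internals):

import time
import torch
import torch.nn.functional as F

# Toy data and a linear model held as plain tensors
xs = torch.randn(300, 3)
ys = torch.randint(0, 10, (300,))
fast = [torch.randn(10, 3, requires_grad=True),  # weight
        torch.zeros(10, requires_grad=True)]     # bias; analogous to fmodel._fast_params

for k in range(20):
    loss = F.cross_entropy(F.linear(xs, fast[0], fast[1]), ys)
    t = time.process_time()
    # create_graph=True keeps the update differentiable, as diffopt.step does
    grads = torch.autograd.grad(loss, fast, create_graph=True)
    print(k, "step", time.process_time() - t)
    fast = [p - 1e-2 * g for p, g in zip(fast, grads)]

If the grad call here stays constant in time while the full script doesn't, the slowdown presumably comes from something higher-specific rather than from autograd itself.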

egrefen commented 4 years ago

Sorry for the delay in looking into this. I believe this problem is linked to something in core pytorch, as flagged in this issue: https://github.com/pytorch/pytorch/issues/12635. Someone is looking into it AFAIK, so I'll report back if progress is made there.

egrefen commented 3 years ago

I'm trying to reproduce this in a clean, minimal way, and am struggling a bit:

import torch
import torch.nn.functional as F
import torchvision
import higher
import time
import numpy as np

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-higher', action="store_true",
                    help="Don't use higher in the inner loop.")
parser.add_argument('--no-unused', action="store_true",
                    help="Don't introduce unused hyperparameter.")
parser.add_argument('-v', '--verbose', action="store_true",
                    help="Print times.")

class Aug_model(torch.nn.Module):
    def __init__(self, model, hyper_param=True):
        super(Aug_model, self).__init__()

        #### Origin of the issue ? ####
        if hyper_param:
            self._params = torch.nn.ParameterDict({
                    "hyper_param": torch.nn.Parameter(torch.Tensor([0.5])),
                })
        ###############################

        self._mods = torch.nn.ModuleDict({
            'model': model,
            })

    def forward(self, x):
        return self._mods['model'](x) #* self._params['hyper_param']

    def __getitem__(self, key):
        return self._mods[key]

def main(args):
    print(args)

    aug_model = Aug_model(
                    model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
                    hyper_param=not args.no_unused #False will not extend step time
                    )
    inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2)

    if not args.no_higher:
        fmodel = higher.patch.monkeypatch(aug_model, copy_initial_weights=False)
        diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(), fmodel=fmodel)

    times = []

    for i in range(50):

        # Fake data
        xs = torch.rand(300, 3, 32, 32)
        ys = torch.ones((300,), dtype=torch.long)

        # Forward step
        if args.no_higher:
            logits = aug_model(xs)
            loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
            loss.backward()
        else:
            logits = fmodel(xs)
            loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')

        t = time.process_time()

        # Optimizer update
        if args.no_higher:
            inner_opt.step()
            inner_opt.zero_grad()
        else:
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

        time_of_step = time.process_time()-t

        if args.verbose:
            print("Time of step {}: {:.2f}ms.".format(i, time_of_step*1000))
        times.append(time_of_step)

    times = np.array(times)
    end_vs_start = times[-10:].mean() / times[:10].mean()
    print("Last 10 steps avg was {:.2f}% of first 10 steps.".format(end_vs_start * 100))

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)

The above code has quite a lot of overlap with yours, and yet doesn't exhibit the increasing step time. However, the code I submitted to https://github.com/pytorch/pytorch/pull/52180 as a higher-free minimal working example does, so I'm scratching my head...
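
One observable difference is the configuration: the original repro (quoted above) uses settings this clean version drops, and flipping them one at a time might isolate the trigger:

# Present in the original repro but not in this clean version:
inner_opt = torch.optim.SGD(params, lr=1e-2, momentum=0.9)               # momentum=0.9 vs. plain SGD
fmodel = higher.patch.monkeypatch(aug_model, copy_initial_weights=True)  # vs. copy_initial_weights=False
device = torch.device('cuda:1')                                          # CUDA + real CIFAR10 batches vs. CPU + torch.rand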