AntoineHX opened this issue 4 years ago.
@AntoineHX This should not be the case. Would you please provide a minimal repro (small self-contained bit of code) which reproduces the behaviour above? We will investigate ASAP.
@egrefen That's good to know. The issue could be on my side, then. It seems to happen when I use a functional model built from a model wrapped inside a class with a parameter (on which I wish to get a gradient).
Here's the repro:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import higher
import time

data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)

class Aug_model(nn.Module):
    def __init__(self, model, hyper_param=True):
        super(Aug_model, self).__init__()
        #### Origin of the issue ? ####
        if hyper_param:
            self._params = nn.ParameterDict({
                "hyper_param": nn.Parameter(torch.Tensor([0.5])),
            })
        ###############################
        self._mods = nn.ModuleDict({
            'model': model,
        })

    def forward(self, x):
        return self._mods['model'](x) #* self._params['hyper_param']

    def __getitem__(self, key):
        return self._mods[key]

if __name__ == "__main__":
    device = torch.device('cuda:1')
    aug_model = Aug_model(
        model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
        hyper_param=True #False will not extend step time
    ).to(device)
    inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
    fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True)
    for i, (xs, ys) in enumerate(dl_train):
        xs, ys = xs.to(device), ys.to(device)
        logits = fmodel(xs)
        loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
        t = time.process_time()
        diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
        print(len(fmodel._fast_params),"step", time.process_time()-t)
```
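In this repro, `hyper_param` is registered on the wrapper but never used in `forward()` (the multiplication is commented out). Below is a minimal plain-PyTorch sketch of how `torch.autograd.grad` treats such a parameter. The tensors are hypothetical stand-ins, and the idea that the differentiable optimizer requests gradients for unused parameters in exactly this way is an assumption, not something confirmed in this thread:

```python
import torch

# Hypothetical stand-ins: `weight` is used by the forward pass, while
# `hyper_param` is registered but unused, as in Aug_model above.
weight = torch.nn.Parameter(torch.tensor([2.0]))
hyper_param = torch.nn.Parameter(torch.tensor([0.5]))

x = torch.tensor([3.0])
loss = (weight * x).sum()  # hyper_param never enters the graph

# By default, asking for the gradient of an input that is not in the graph raises.
try:
    torch.autograd.grad(loss, [weight, hyper_param], retain_graph=True)
    raised = False
except RuntimeError:
    raised = True

# With allow_unused=True, the unused input simply gets None instead of a tensor.
grads = torch.autograd.grad(loss, [weight, hyper_param], allow_unused=True)
print(raised, grads[0], grads[1])
```

Passing `hyper_param=False` to `Aug_model` removes the unused parameter altogether, which matches the "False will not extend step time" observation in the repro.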
It's entirely possible the problem is not your fault. Either way, this is not normal and we'll get to the bottom of this. To manage expectations, I plan to take a look at this on Monday next week, but please do let me know if this is somehow more time-sensitive.
Do you have any leads? Let me know if I can help :smiley:
I can repro your issue on CPU using the script you provided. Thanks for narrowing it down to the step where the higher-order graph is created. Running your code with `track_higher_grads=False` in the diffopt creation yields a constant step time.
I'm not entirely sure why this is happening: the backward graph is bigger, I suppose, but it should already be partly built, so I'm not sure what is causing a bigger graph to be created from scratch. I'll have a think about this later this week, but I think we need to narrow it down a bit and/or find the simplest repro that doesn't depend deeply on higher, so I can go to the PyTorch team for help if needed.
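Here is a higher-free sketch of the mechanism discussed above, under stated assumptions. A toy linear least-squares model is trained with a hand-unrolled, differentiable SGD loop. The flag `track_higher_grads` only borrows the name of higher's option: here it is implemented as `create_graph` plus detaching the parameters after each step. That is an assumed approximation of its semantics, not higher's code. The hypothetical helper `count_graph_nodes` walks `grad_fn.next_functions` to count the autograd nodes reachable from the final loss:

```python
import torch


def count_graph_nodes(tensor):
    """Count the autograd nodes reachable from tensor.grad_fn (iterative DFS)."""
    seen = {}  # id -> node; keeping the node objects alive keeps their ids unique
    stack = [tensor.grad_fn] if tensor.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node
        for next_node, _ in node.next_functions:
            if next_node is not None:
                stack.append(next_node)
    return len(seen)


def final_loss_graph_size(num_steps, track_higher_grads, lr=0.1):
    """Unroll num_steps of SGD on a toy problem; return the final loss's graph size."""
    torch.manual_seed(0)
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)
    params = torch.zeros(4, 1, requires_grad=True)  # initial weights (a leaf)
    for _ in range(num_steps):
        loss = ((x @ params - y) ** 2).mean()
        # create_graph=True makes the gradient itself differentiable, so the
        # update below links the new params to every previous step.
        (g,) = torch.autograd.grad(loss, params, create_graph=track_higher_grads)
        params = params - lr * g
        if not track_higher_grads:
            # Cut the history: the next step starts from a fresh leaf tensor.
            params = params.detach().requires_grad_()
    final_loss = ((x @ params - y) ** 2).mean()
    return count_graph_nodes(final_loss)


for steps in (1, 5, 10):
    print(steps,
          final_loss_graph_size(steps, track_higher_grads=True),
          final_loss_graph_size(steps, track_higher_grads=False))
```

With the flag on, the node count grows with every step, so any autograd traversal that starts from the loss and walks the whole reachable graph gets more expensive each step. With the flag off, the count stays constant. This is consistent with the slowdown reported here, but does not by itself prove its cause.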
Sorry for the delay in looking into this. I believe this problem is linked to something in core pytorch, as flagged in this issue: https://github.com/pytorch/pytorch/issues/12635. Someone is looking into it AFAIK, so I'll report back if progress is made there.
I'm trying to reproduce this in a clean, minimal way, and am struggling a bit:
```python
import torch
import torch.nn.functional as F
import torchvision
import higher
import time
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-higher', action="store_true",
                    help="Don't use higher on inner loop.")
parser.add_argument('--no-unused', action="store_true",
                    help="Don't introduce unused hyperparameter.")
parser.add_argument('-v', '--verbose', action="store_true",
                    help="Print times.")

class Aug_model(torch.nn.Module):
    def __init__(self, model, hyper_param=True):
        super(Aug_model, self).__init__()
        #### Origin of the issue ? ####
        if hyper_param:
            self._params = torch.nn.ParameterDict({
                "hyper_param": torch.nn.Parameter(torch.Tensor([0.5])),
            })
        ###############################
        self._mods = torch.nn.ModuleDict({
            'model': model,
        })

    def forward(self, x):
        return self._mods['model'](x) #* self._params['hyper_param']

    def __getitem__(self, key):
        return self._mods[key]

def main(args):
    print(args)
    aug_model = Aug_model(
        model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
        hyper_param=not args.no_unused #False will not extend step time
    )
    inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2)
    if not args.no_higher:
        fmodel = higher.patch.monkeypatch(aug_model, copy_initial_weights=False)
        diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel)

    times = []
    for i in range(50):
        # Fake data
        xs = torch.rand(300, 3, 32, 32)
        ys = torch.ones((300,), dtype=torch.long)

        # Forward step
        if args.no_higher:
            logits = aug_model(xs)
            loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
            loss.backward()
        else:
            logits = fmodel(xs)
            loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')

        t = time.process_time()
        # Optimizer update
        if args.no_higher:
            inner_opt.step()
            inner_opt.zero_grad()
        else:
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
        time_of_step = time.process_time()-t
        if args.verbose:
            print("Time of step {}: {:.2f}ms.".format(i, time_of_step*1000))
        times.append(time_of_step)

    times = np.array(times)
    # Compare the last 10 steps with the first 10, as the message below states
    end_vs_start = times[-10:].mean()/times[:10].mean()
    print("Last 10 steps avg was {:.2f}% of first 10 steps.".format(end_vs_start*100))

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
```
The above code has quite a lot of overlap with yours, and yet doesn't exhibit the increasing time per step. However, the code I submitted to https://github.com/pytorch/pytorch/pull/52180 as a higher-free minimal working example does, so I'm scratching my head...
Hi,
I was wondering whether it is intended that the `diffopt.step(loss)` command takes increasingly long to execute as the number of saved states grows. The step time should be constant, since each step back-propagates through only one state, while computing the meta-gradient at the end of the inner loop should take longer as the number of saved states increases, right? This code gives me:
EDIT:
After looking into the DifferentiableOptimizer code, it seems that what's causing this slowdown is the building of the graph of the derivative in
I'm not really familiar with how autograd handles this, but it seems the whole graph is computed at each call. Isn't it possible to keep the previous graph and extend it with the new states as the gradient tape grows?
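The expectation above is that per-step updates stay cheap, while the meta-gradient at the end of the inner loop gets more expensive with the number of saved states. Below is a minimal plain-PyTorch sketch of such a meta-gradient. It assumes a toy float64 least-squares problem and treats the learning rate `lr` as the hyperparameter; neither comes from this thread, and this is not higher's implementation. The final loss is differentiated through every unrolled step with respect to `lr`, and the result is checked against a central finite difference:

```python
import torch


def unrolled_final_loss(lr, num_steps=5):
    """Final loss after num_steps of differentiable SGD on a fixed toy problem."""
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(16, 3, generator=gen, dtype=torch.float64)
    y = torch.randn(16, 1, generator=gen, dtype=torch.float64)
    params = torch.zeros(3, 1, dtype=torch.float64, requires_grad=True)
    for _ in range(num_steps):
        loss = ((x @ params - y) ** 2).mean()
        # Keep the graph of each gradient so the final loss depends on lr
        # through every update, not just the last one.
        (g,) = torch.autograd.grad(loss, params, create_graph=True)
        params = params - lr * g
    return ((x @ params - y) ** 2).mean()


lr = torch.tensor(0.1, dtype=torch.float64, requires_grad=True)
final_loss = unrolled_final_loss(lr)
# Meta-gradient: one backward pass through all num_steps saved states.
(meta_grad,) = torch.autograd.grad(final_loss, lr)

# Central finite difference as an independent check.
eps = 1e-6
loss_plus = unrolled_final_loss(torch.tensor(0.1 + eps, dtype=torch.float64)).item()
loss_minus = unrolled_final_loss(torch.tensor(0.1 - eps, dtype=torch.float64)).item()
finite_diff = (loss_plus - loss_minus) / (2 * eps)

print(meta_grad.item(), finite_diff)
```

Only the final `torch.autograd.grad(final_loss, lr)` call is meant to pay for the whole unrolled history. Inside the loop, each call only asks for the gradient with respect to the current `params`, which is why a growing per-step cost is surprising.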