Taha-Bahadori opened this issue 7 years ago
@soumith: I created a subclass to do this as follows. It works as I described above.
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

class ReduceLROnPlateauBT(ReduceLROnPlateau):
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8, model=None):
        super(ReduceLROnPlateauBT, self).__init__(
            optimizer, mode=mode, factor=factor, patience=patience,
            verbose=verbose, threshold=threshold, threshold_mode=threshold_mode,
            cooldown=cooldown, min_lr=min_lr, eps=eps)
        self.model = model
        self.model_state_dict = None if model is None else model.state_dict()

    def step(self, metrics, epoch=None):
        current = metrics
        if epoch is None:
            epoch = self.last_epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
            if self.model is not None:  # Saving good models
                self.model_state_dict = self.model.state_dict()
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            if self.model is not None:  # Loading good models
                self.model.load_state_dict(self.model_state_dict)
```
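For readers skimming the thread, here is a rough usage sketch of the subclass above. The model, data, and hyperparameters are placeholders I made up for illustration (and it assumes a PyTorch version where `loss.item()` exists), not part of the original post:

```python
import torch

# Illustrative sketch only: model, data, and hyperparameters are placeholders.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Pass the model so the scheduler can snapshot/restore its weights.
scheduler = ReduceLROnPlateauBT(optimizer, mode='min', patience=5, model=model)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for epoch in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # The metric drives both the LR reduction and the backtracking reload.
    scheduler.step(loss.item())
```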
Hi, I am the author of https://github.com/pytorch/pytorch/pull/1370
I currently do this via some code in the main loop. Since ReduceLROnPlateau only has access to the optimizer, and optimizer.state_dict() does NOT include the parameter values themselves (I guess this is some kind of bug), backtracking cannot be done very naturally.
```python
# after optim.load_state_dict( ... )
In [17]: optim.state_dict()
Out[17]:
{'param_groups': [{'betas': (0.9, 0.999),
   'eps': 1e-08,
   'lr': 0.001,
   'params': [139683866380144,
    139683866381584,
    139683866379472,
    139683866381680],
   'weight_decay': 0}],
 'state': {}}
```
@soumith any idea?
Edit: I believe that optim.state_dict() does not contain all the parameters of a model (e.g. BatchNorm's running mean and variance). Therefore we still need access to the nn.Module even without this issue.
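A quick way to see both points; this is my own illustration with a BatchNorm1d layer, and the exact keys depend on the PyTorch version:

```python
import torch

bn = torch.nn.BatchNorm1d(3)
optim = torch.optim.SGD(bn.parameters(), lr=0.1)

# The module's state_dict includes the non-parameter buffers
# (running_mean, running_var, ...).
print(bn.state_dict().keys())

# The optimizer's state_dict only knows about its own param groups and history.
print(optim.state_dict().keys())
```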
@Taha-Bahadori Would you please make it a PR?
By the way, `self.model_state_dict = self.model.state_dict()` would NOT make a copy of the current state. You might have to pickle the state_dict for future loading. You can see what happens by running this snippet:
```python
import torch

m = torch.nn.Linear(1, 2)
optim = torch.optim.Adam(m.parameters())

state_dict = m.state_dict()
print(state_dict)

# Modifying the weight through a fresh state_dict also changes the one
# "saved" above, because both refer to the same underlying tensors.
m.state_dict()['weight'][0] = 1000
print(state_dict)
```
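For reference, one way to take a snapshot that survives later in-place updates is to deep-copy the state_dict; this is a sketch of my own, along the lines of the pickling suggestion above:

```python
import copy
import torch

m = torch.nn.Linear(1, 2)

# deepcopy clones the tensors, so later in-place edits do not leak into it.
snapshot = copy.deepcopy(m.state_dict())
m.state_dict()['weight'][0] = 1000

print(snapshot['weight'])        # still the original values
print(m.state_dict()['weight'])  # first row is now 1000
```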
@Jiaming-Liu I think there is a mistake in your code snippet. Here is an example showing that saving and loading the state_dict as above should work:
```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.mem = nn.Parameter(torch.zeros(1))

m = M()
print(m)
print(m.state_dict())

# Saving the current state_dict
sd = m.state_dict()

# Changing the value of the parameter
m.mem.data = torch.ones(1)
print(m.state_dict())

# Now loading the original zero parameter back
m.load_state_dict(sd)
print(m.state_dict())
```
@Taha-Bahadori I think the snippet below is closer to the real use case. Note that in-place tensor operations are used in optimizer.step(); that is, `m.mem.data = torch.ones(1)` in your snippet should be something like `m.mem.data.copy_(torch.ones(1))` instead.
Also, your ReduceLROnPlateauBT ignores the state_dict of the optimizer, which contains important history information (like momentum for SGD).
```python
import torch

net = torch.nn.Linear(1, 2)
optim = torch.optim.Adam(net.parameters())

state_dict = net.state_dict()
print(state_dict)

x = torch.FloatTensor([[1], [2]])
x = torch.autograd.Variable(x)
y = torch.FloatTensor([[0, 1], [2, 3]])
y = torch.autograd.Variable(y)

loss = torch.nn.functional.mse_loss(net(x), y)
loss.backward()
optim.step()  # Changing the value of the parameter

net.load_state_dict(state_dict)
print(net.state_dict())
```
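To make the point about optimizer history concrete, here is a sketch that deep-copies both state_dicts before stepping and restores them afterwards; the names and data are illustrative, not from the original comment:

```python
import copy
import torch

net = torch.nn.Linear(1, 2)
optim = torch.optim.Adam(net.parameters())

# Deep-copied snapshots: the model weights and the optimizer history.
best_model_sd = copy.deepcopy(net.state_dict())
best_optim_sd = copy.deepcopy(optim.state_dict())

x = torch.autograd.Variable(torch.FloatTensor([[1], [2]]))
y = torch.autograd.Variable(torch.FloatTensor([[0, 1], [2, 3]]))
loss = torch.nn.functional.mse_loss(net(x), y)
loss.backward()
optim.step()  # updates the parameters in place and Adam's moment buffers

# Backtrack: restore both the weights and the optimizer's history.
net.load_state_dict(best_model_sd)
optim.load_state_dict(best_optim_sd)
print(net.state_dict())  # matches the snapshot again
```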
@vincentqb can you make a call on this?
Given the cost associated with backtracking, this mechanism should be left in the control of the user. The scheduler could have a flag that indicates its status, though, so the user can save/load based on that flag; see comment.
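For what it's worth, a rough sketch of what keeping backtracking in user code could look like; detecting the reduction by comparing the learning rate before and after scheduler.step() is my own workaround, not an official API, and the model/data are placeholders:

```python
import copy
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Illustrative sketch only.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, patience=5)

best_loss = float('inf')
best_state = copy.deepcopy(model.state_dict())

x, y = torch.randn(32, 10), torch.randn(32, 1)
for epoch in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    val_loss = loss.item()
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    if optimizer.param_groups[0]['lr'] < lr_before:
        # The scheduler just reduced the LR: backtrack to the best weights.
        model.load_state_dict(best_state)
```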
This is a really nice issue. So, what is the final decision on "needs research"?
Is it possible to implement a simple backtracking for the ReduceLROnPlateau module? That is, store the best model coefficients and reload them upon rate reduction.
In my experiments, this helps speed up learning, though it might be expensive for very large models.
cc @vincentqb