@calvinmccarter Thanks for pointing this out - I get the same result on my machine and I'm trying to sort this out. At the moment I cannot reproduce the error using PyTorch's native L-BFGS algorithm (see below), so perhaps the issue is with torchmin, not torch.
Code:
import torch

# myfun is the objective from the original report (defined further down)
xy = torch.tensor([1.0, 0.0], requires_grad=True)
optimizer = torch.optim.LBFGS([xy],
                              line_search_fn='strong_wolfe',
                              max_iter=10)

def closure():
    optimizer.zero_grad()
    loss = myfun(xy)
    loss.backward()
    return loss

loss = optimizer.step(closure)
print(loss)
Output:
tensor(-148.4767, grad_fn=<SubBackward0>)
@calvinmccarter - after a deeper look, the issue here is your objective function.
Your objective function is only defined on the positive reals (log(x) is nan for x < 0). L-BFGS and the associated line search methods are not designed to handle constrained optimization problems like yours, so there is no guarantee that they will work properly. The fact that CG worked is just coincidence: you got lucky that the optimization path never crossed into the negative domain.
Perhaps we should build in a sanity check that looks for nans in the objective to help troubleshoot issues like these.
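A minimal sketch of what such a check could look like, as a wrapper around the closure-style interface used above (the name checked_closure is hypothetical, not an existing torchmin API):

import torch

def checked_closure(closure):
    # Evaluate the user's closure and fail loudly if the objective is
    # nan or inf, instead of letting the line search silently misbehave.
    def wrapped():
        loss = closure()
        if not torch.isfinite(loss):
            raise ValueError(f'objective is non-finite ({loss.item()}); '
                             'the iterate likely left the feasible domain')
        return loss
    return wrapped

loss = optimizer.step(checked_closure(closure))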
Yeah, I'm using the log as a log-barrier. Fwiw, the problem is fixed by using a custom autograd Function that returns -inf for negative inputs instead of nan, as below:
import torch

class MyLog(torch.autograd.Function):
    # Like torch.log, but returns -inf (rather than nan) on the
    # infeasible region inp < 0.
    @staticmethod
    def forward(ctx, inp):
        result = torch.log(inp)
        result[inp < 0] = float('-inf')
        ctx.save_for_backward(inp)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        grad_input = grad_output / inp
        grad_input[inp < 0] = float('inf')
        return grad_input

# r_1, ..., r_7 are coefficients specific to the original problem,
# defined elsewhere.
def myfun(xy):
    x, y = xy[0], xy[1]
    obj = (
        round(r_1, 5) * MyLog.apply(x)
        + round(r_2 + r_3, 5) * (x**2)
        + round(r_4, 5) * x * y
        + round(r_5, 5) * x
        + round(r_6, 5) * (y**2)
        + round(r_7, 5) * y
    )
    return obj
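To illustrate the difference (my own quick check, not part of the fix itself): unlike nan, -inf still compares consistently against finite values, which is what keeps the line search's bracketing logic sane:

x = torch.tensor([-1.0], requires_grad=True)
print(torch.log(x))       # tensor([nan], ...)  -- poisons every comparison downstream
print(MyLog.apply(x))     # tensor([-inf], ...) -- still orders correctly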
Arguably, there is a bug in torch/optim/lbfgs.py:

low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)

Their strong Wolfe zoom phase chooses (low_pos, high_pos) = (1, 0) even when (bracket_f[0], bracket_f[-1]) = (finite_number, nan), because a comparison against nan always evaluates to False; the nan endpoint is then treated as the bracket point with the lower objective value. I'll create an issue with torch.
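For reference, the failure mode is easy to reproduce in plain Python:

nan = float('nan')
bracket_f = [-148.4767, nan]   # (finite_number, nan)

# Mirrors the selection line quoted above: the nan endpoint wins the
# "low" slot because -148.4767 <= nan is False.
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
print(low_pos, high_pos)       # 1 0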
The following problem succeeds with other methods (e.g. 'cg'), but fails with 'l-bfgs'. However, the error seems to arise in the PyTorch backend, so perhaps I should instead file this as an issue with PyTorch.
The error output is: