facebookresearch / theseus

A library for differentiable nonlinear optimization
MIT License

Huber Loss Not Correct #497

Closed xphsean12 closed 1 year ago

xphsean12 commented 1 year ago

🐛 Bug

I added the Huber loss to my current cost function. However, I found that the values produced with the Huber loss may be incorrect.

Steps to Reproduce

Steps to reproduce the behavior: First, I checked the cost without the Huber loss; the result is shown below. [image] Then I added the Huber loss with a sufficiently large radius (100). Given how the Huber loss works, the cost should stay the same, or at least follow a similar trend to the original cost. [image] However, the result is as below. [image] In addition, below is the cost with a radius of 1. [image]

I added the Huber loss following the example below: [image]

Thanks.

Expected behavior

Additional context

System Info

luisenp commented 1 year ago

Are you accounting for robust cost receiving the log radius? (see here)

xphsean12 commented 1 year ago

> Are you accounting for robust cost receiving the log radius? (see here)

Sorry, I do not understand what you mean.

luisenp commented 1 year ago

The robust function receives the logarithm of the radius, not the radius itself: the loss is evaluated using a radius of exp(log_loss_radius). For example, if you pass a value of 1.0, the actual radius is 2.7182...
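A minimal plain-PyTorch sketch of the conversion described above (the variable names are illustrative, not Theseus internals): to get an effective radius of 100, the stored log radius should be log(100), about 4.605. It also shows that using 100 itself as the log radius overflows float32 when exponentiated.

```python
import math

import torch

# Illustrative names: the value stored in the log-radius variable, shape (batch, 1).
desired_radius = 100.0
log_loss_radius = torch.full((1, 1), math.log(desired_radius))

# The loss exponentiates the stored value to get the radius it actually uses.
effective_radius = log_loss_radius.exp()
print(effective_radius.item())  # approximately 100.0

# Passing 1.0 gives a radius of e = 2.718..., not 1.
print(torch.tensor(1.0).exp().item())

# Passing 100.0 as the *log* radius: exp(100) exceeds the float32 range.
print(torch.tensor(100.0).exp())  # tensor(inf)
```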

Also, looking at your original post: can you share how you are computing the numbers shown in the screenshot? The behavior differs depending on which method you call, as explained here.

xphsean12 commented 1 year ago

Thanks for your reply. The error in the screenshot was printed by objective.error(), where objective is the instantiated Objective. In the Huber version, the only difference is that I wrap the original cost function with a Huber loss.

xphsean12 commented 1 year ago

I can use Tutorial 1 as an example. Below are the code and error without the Huber loss: [image] [image] And below are the code and error with the Huber loss: [image] [image]

luisenp commented 1 year ago

The reason you get this output is that RobustCostFunction is meant to evaluate loss(||cf.weighted_error()||**2), where cf is the original cost function. Note that this is a dim=1 cost function, but because of the way our internal solver works, it needs to be consistent with the original dim=100 cost function. Thus, we return a constant vector whose norm is the same as the norm of the 1D robust cost (see this comment). You can confirm that the loss function with the given radius is not being applied by running the comparison below (up to some numerical differences):

[image]

Note that the loss function is applied to the squared norm of the weighted error (see theory), not to each error element independently. Perhaps this is what you were expecting?
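The following plain-PyTorch sketch illustrates the idea described above: a Huber-style function is applied to the squared norm of each batch element's error, and a constant vector of the original dimension is returned whose squared norm equals the robust value. The Huber form used here (ρ(s) = s for s ≤ r, 2√(r·s) - r otherwise), the helper names `huber_on_squared_norm` and `robust_error_vector`, and the use of the raw error instead of `cf.weighted_error()` are all assumptions for illustration, not Theseus's internal code.

```python
import torch


def huber_on_squared_norm(s: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
    # Hypothetical Huber-style rho(s) on a squared norm s >= 0:
    # unchanged for s <= radius, 2 * sqrt(radius * s) - radius beyond it.
    return torch.where(s > radius, 2 * torch.sqrt(radius * s) - radius, s)


def robust_error_vector(error: torch.Tensor, log_radius: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. error: (batch, dim); log_radius: (batch, 1).
    dim = error.shape[1]
    radius = log_radius.exp()
    squared_norm = error.norm(dim=1, keepdim=True) ** 2  # (batch, 1)
    robust_value = huber_on_squared_norm(squared_norm, radius)  # (batch, 1)
    # Constant (batch, dim) vector whose squared norm equals robust_value.
    return (robust_value / dim).sqrt().expand(-1, dim)


error = torch.tensor([[3.0, 4.0, 0.0, 0.0]])  # squared norm 25

# Radius exp(0) = 1: the Huber branch is active, 2 * sqrt(25) - 1 = 9.
v = robust_error_vector(error, torch.zeros(1, 1))
print(v.norm(dim=1) ** 2)  # approximately 9

# Radius exp(10) ~ 22026 > 25: the loss acts as the identity, but the entries
# are all sqrt(25 / 4) = 2.5 instead of [3, 4, 0, 0].
v_big = robust_error_vector(error, torch.full((1, 1), 10.0))
print(v_big)
```

In this sketch, even when the radius is large enough that the loss changes nothing, the printed error vector looks different from the original one; only its norm is preserved.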

xphsean12 commented 1 year ago

Thanks for the clarification. The values make sense if the Huber function is applied to the squared norm. However, based on the reference and my colleague's experience with Ceres, I think robust losses are originally intended to be applied to the residuals of the optimization problem: "For least squares problems where the minimization may encounter input terms that contain outliers, that is, completely bogus measurements, it is important to use a loss function that reduces their influence." Take the estimation problem y = kx + b as an example, where k and b are the parameters to be estimated. If the pair (x3, y3) is an outlier, the robust function should be able to suppress its effect. [image] In this case, f_i is the corresponding residual, y_i - (k * x_i + b), where k and b are the current estimates of k and b.

Below is a real example in Ceres; the second column is the residual. One image is without the Huber loss and the other is with it: [image] [image]
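To make the difference concrete, here is a small plain-PyTorch sketch (not Theseus or Ceres code) for the y = kx + b example. It compares applying a Huber-style function to each squared residual f_i² separately, as in the Ceres formulation quoted above, with applying it once to the squared norm of all stacked residuals. The data points, the helper names `huber` and `huber_weight`, and the Huber form are made up for illustration. `huber_weight` is the derivative ρ'(s), which scales each residual's contribution to the gradient.

```python
import torch


def huber(s: torch.Tensor, radius: float) -> torch.Tensor:
    # Hypothetical Huber-style function rho(s) on a squared residual s >= 0.
    return torch.where(s > radius, 2 * torch.sqrt(radius * s) - radius, s)


def huber_weight(s: torch.Tensor, radius: float) -> torch.Tensor:
    # rho'(s): 1 for s <= radius, sqrt(radius / s) beyond it.
    return torch.sqrt(radius / s.clamp(min=radius))


# Made-up points near y = 2x + 1; the third point is an outlier.
x = torch.tensor([0.0, 1.0, 2.0, 3.0])
y = torch.tensor([1.5, 2.5, 15.0, 7.0])
k, b = 2.0, 1.0  # current estimates of k and b
f = y - (k * x + b)  # residuals f_i = [0.5, -0.5, 10.0, 0.0]
radius = 1.0

# Ceres-style: one robust term per residual, sum_i rho(f_i^2).
per_residual_loss = huber(f**2, radius).sum()
per_residual_weights = huber_weight(f**2, radius)

# Whole-vector: a single robust term rho(||f||^2).
squared_norm = (f**2).sum()
whole_vector_loss = huber(squared_norm, radius)
whole_vector_weight = huber_weight(squared_norm, radius)

print(per_residual_loss.item(), per_residual_weights)
print(whole_vector_loss.item(), whole_vector_weight.item())
```

With one robust term per residual, only the outlier is down-weighted (weight 0.1). With a single term on the squared norm, the outlier pushes the whole norm past the radius, so every residual, inliers included, gets the same weight of about 0.0998.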

luisenp commented 1 year ago

I'm a bit confused by what you mean. In the writeup above, the loss is applied as $\rho(||f_i||^2)$. In Theseus terminology, the residuals f_i would be the result of cost_function.error(), so that the robust function is applied to cost_function.error().norm(dim=1) ** 2. If you want them to be applied separately, you can have a separate cost function for each residual term, rather than a single cost function that holds all residuals. Is this what you mean?

luisenp commented 1 year ago

You can try the code below, which prints the following.

[image]

Note that when passing the data to Theseus (near the end of the script), I'm making the first 10 points outliers, to show that the loss function is truncating their values. Is this the behavior you expect?

```python
import torch

torch.manual_seed(0)

def generate_data(num_points=100, a=1, b=0.5, noise_factor=0.01):
    # Generate data: num_points points sampled from y = a * x^2 + b, plus Gaussian noise
    data_x = torch.rand((1, num_points))
    noise = torch.randn((1, num_points)) * noise_factor
    data_y = a * data_x.square() + b + noise
    return data_x, data_y

data_x, data_y = generate_data()

import theseus as th

# optimization variables are of type Vector with 1 degree of freedom (dof)
a = th.Vector(1, name="a")
b = th.Vector(1, name="b")

def quad_error_fn(optim_vars, aux_vars):
    a, b = optim_vars
    x, y = aux_vars
    est = a.tensor * x.tensor.square() + b.tensor
    err = y.tensor - est
    return err

optim_vars = a, b

# data is of type Variable. Make a separate variable for each data point, to have separate residuals
xs = [th.Variable(torch.zeros(1, 1), name=f"x_{i}") for i in range(data_x.shape[1])]
ys = [th.Variable(torch.zeros(1, 1), name=f"y_{i}") for i in range(data_y.shape[1])]

# Log of the loss radius (the effective radius is exp of this value)
log_loss_radius = th.Vector(1, name="log_loss_radius", dtype=torch.float32)
objective = th.Objective()
# A separate cost function for each residual
for i in range(data_x.shape[1]):
    cf = th.AutoDiffCostFunction(
        optim_vars, quad_error_fn, 1, aux_vars=(xs[i], ys[i]), name=f"residual_{i}"
    )
    robust_cf = th.RobustCostFunction(
        cf, th.HuberLoss, log_loss_radius, name=f"robust_{cf.name}"
    )
    objective.add(robust_cf)
optimizer = th.GaussNewton(
    objective,
    max_iterations=15,
    step_size=0.5,
)
theseus_optim = th.TheseusLayer(optimizer, vectorize=True)

# Create inputs
theseus_inputs = {f"x_{i}": data_x[:, i : i + 1] for i in range(data_x.shape[1])}
theseus_inputs.update({f"y_{i}": data_y[:, i : i + 1] for i in range(data_x.shape[1])})
# Make the first 10 points outliers (out-of-place add, so the data_y views are not modified)
for i in range(10):
    theseus_inputs[f"y_{i}"] = theseus_inputs[f"y_{i}"] + 100
theseus_inputs.update({"a": 2 * torch.ones((1, 1)), "b": torch.ones((1, 1))})

# Run optimization
with torch.no_grad():
    updated_inputs, info = theseus_optim.forward(
        theseus_inputs, optimizer_kwargs={"track_best_solution": True, "verbose": True}
    )
print("Best solution:", info.best_solution)

print(objective.error())
```
xphsean12 commented 1 year ago

Thanks for your reply! This is exactly what I wanted to achieve, although using the Huber loss this way is more complicated than I expected. In many cases, the cost function's error is a vector, and wrapping each element of that vector in its own robust cost function would be very tedious. Still, many thanks for all your great work!!

luisenp commented 1 year ago

One idea would be to have a keyword argument in RobustCostFunction that makes it compute the loss as if it were dim different cost functions stacked together. I think this should be easy to implement by temporarily flattening the error to shape (B x D, 1), applying the loss, and then reshaping to the proper output size (and similarly when computing the Jacobians).
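A minimal plain-PyTorch sketch of the flattening idea, outside Theseus and covering only the error (not the Jacobians): reshape a (B, D) error to (B·D, 1), apply the loss as if each entry were a dim=1 cost, and reshape back. The helper names and the Huber form are hypothetical. Each dim=1 robust error is taken as sqrt(ρ(e²)), which drops the sign of the entry; whether that matches Theseus's dim=1 behavior is an assumption. The final check compares against D separate dim=1 robust costs.

```python
import torch


def huber(s: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
    # Hypothetical Huber-style function on squared values s >= 0.
    return torch.where(s > radius, 2 * torch.sqrt(radius * s) - radius, s)


def robust_error_1d(error: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
    # Robust error of a dim=1 cost: error (N, 1) -> sqrt(rho(e^2)), shape (N, 1).
    return huber(error**2, radius).sqrt()


def robust_error_elementwise(error: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
    # Hypothetical "per-element" mode: error (B, D), radius (B, 1).
    batch_size, dim = error.shape
    flat_error = error.reshape(batch_size * dim, 1)  # (B*D, 1)
    # Row-major flattening keeps each batch row's D entries contiguous, so the
    # per-batch radius is repeated D times to line up with them.
    flat_radius = radius.repeat_interleave(dim, dim=0)  # (B*D, 1)
    return robust_error_1d(flat_error, flat_radius).reshape(batch_size, dim)


torch.manual_seed(0)
error = 3 * torch.randn(2, 5)  # batch size 2, error dimension 5
radius = torch.tensor([[1.0], [4.0]])  # one radius per batch element

stacked = robust_error_elementwise(error, radius)

# Reference: D separate dim=1 robust costs, one per error entry.
separate = torch.cat(
    [robust_error_1d(error[:, i : i + 1], radius) for i in range(error.shape[1])],
    dim=1,
)
print(torch.allclose(stacked, separate))  # True
```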

Would you be interested in contributing this enhancement?

xphsean12 commented 1 year ago

I can try. There is one thing I am not sure about: [image] In PyTorch, as far as I know, Tensor.max(a) is equivalent to torch.max(Tensor, dim=a), so what does x.max(radius) mean here?

luisenp commented 1 year ago

This doesn't seem to be documented, but when the second argument is another tensor, this seems to work just like torch.maximum. I admit this makes the code confusing though.
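A small sketch of the two overloads in question; the tensor values are made up. For what it's worth, recent PyTorch documentation for `torch.max` does list a `torch.max(input, other)` form that refers to `torch.maximum`. With an integer argument, `Tensor.max` reduces along that dimension. With a tensor argument, it takes an elementwise (broadcasting) maximum, which in a robust-loss context clamps `x` from below at `radius`.

```python
import torch

x = torch.tensor([[0.5, 3.0], [2.0, 0.1]])
radius = torch.tensor([[1.0], [1.0]])

# Integer argument: reduction along dim 1, returns (values, indices).
values, indices = x.max(1)
print(values, indices)  # values [3.0, 2.0], indices [1, 0]

# Tensor argument: elementwise maximum with broadcasting, like torch.maximum.
clamped = x.max(radius)
print(clamped)  # values [[1.0, 3.0], [2.0, 1.0]]
print(torch.equal(clamped, torch.maximum(x, radius)))  # True
```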

[image]

xphsean12 commented 1 year ago

Hello. I have modified the robust cost function to support vector-valued errors. Although the code has not been thoroughly verified, it works. Below is an example based on Tutorial 1:

  1. Create the outliers with 2 batches [image]
  2. Result comparison without/with the robust loss: (a) without robust loss [image] (b) with robust loss [image]
xphsean12 commented 1 year ago

By the way, below is the original RobustLoss

[image]

luisenp commented 1 year ago

This looks great, thanks for looking into this! Can you open a PR? We can also add a unit test that compares the results of the old "one cost function per residual" version with the result of this new version (they should be the same).

xphsean12 commented 1 year ago

> This looks great, thanks for looking into this! Can you open a PR? We can also add a unit test that compares the results of the old "one cost function per residual" version with the result of this new version (they should be the same).

Sure. As I understand it, the old version should give the same result as the new one only if the residual is one-dimensional.

luisenp commented 1 year ago

That's correct. To clarify, what I mean for the test is that, implementing as N individual robust cost functions with dim=1 (without the new flag turned on) should give the same result as implementing as a single dim=N robust cost function with the flag on.

xphsean12 commented 1 year ago

I'm not sure about this; I did not add a flag. Instead, I handled it by checking the shape of the original cost function's Jacobian (though I am not sure whether this covers all cases).

xphsean12 commented 1 year ago

> That's correct. To clarify, what I mean for the test is that, implementing as N individual robust cost functions with dim=1 (without the new flag turned on) should give the same result as implementing as a single dim=N robust cost function with the flag on.

Oh, I see what you mean. Then I should keep the original code rather than delete it.