kocurvik / robust_self_calibration

Predicting Focal Length by Batched Operations #2

Open jytime opened 6 months ago

jytime commented 6 months ago

Hi Viktor,

Thanks for your great work! I am trying to estimate the relative pose of two images given their 2D matches. So far I have found that the fundamental matrix estimate is quite good (from the 7pt and 8pt minimal solvers with RANSAC), but the quality of the relative pose is low because the focal length is unknown.

In my trials, I assume the focal length to be 1.2 times the longer image side and the principal point to be the image center. The fundamental matrix is converted to the essential matrix, and then to R and t. Qualitatively, the pose result is quite good with the ground-truth focal length and the assumed principal point, but it degrades clearly with the assumed focal length. So I am looking for potential solutions, and (luckily!) found your work :)

I'm looking to run the predictions with batched operations, as I may have more than 10,000 batches per run. For instance, when using RANSAC+7p, I select different sets of points, construct pts1 and pts2 with shape B x Ransac_num x 7 x 2, reshape them to (B*Ransac_num) x 7 x 2, build the A matrix, run SVD, and find the solutions.
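For reference, a minimal sketch of the batched setup described above. It only builds the A matrix and extracts the two null-space vectors via `torch.linalg.svd`; the cubic that the 7-point solver still has to solve is left out. The function names and the convention x2^T F x1 = 0 (with F flattened row-major) are assumptions, not taken from the repository.

```python
import torch


def epipolar_design_matrix(pts1, pts2):
    """pts1, pts2: (N, P, 2) matched points.
    Returns A of shape (N, P, 9) such that A @ vec(F) = 0 encodes
    x2^T F x1 = 0, with F flattened row-major (assumed convention)."""
    x1, y1 = pts1[..., 0], pts1[..., 1]
    x2, y2 = pts2[..., 0], pts2[..., 1]
    ones = torch.ones_like(x1)
    return torch.stack(
        [x2 * x1, x2 * y1, x2, y2 * x1, y2 * y1, y2, x1, y1, ones], dim=-1
    )


def nullspace_basis_7pt(pts1, pts2):
    """pts1, pts2: (N, 7, 2). Returns F_a, F_b of shape (N, 3, 3) spanning
    the 2D null space of A. A 7-point solver would then search for
    F = a * F_a + (1 - a) * F_b with det(F) = 0 (not done here)."""
    A = epipolar_design_matrix(pts1, pts2)               # (N, 7, 9)
    _, _, Vh = torch.linalg.svd(A, full_matrices=True)   # Vh: (N, 9, 9)
    F_a = Vh[:, -1].reshape(-1, 3, 3)
    F_b = Vh[:, -2].reshape(-1, 3, 3)
    return F_a, F_b
```

With inputs of shape B x Ransac_num x 7 x 2, reshaping to (B*Ransac_num) x 7 x 2 before the call and reshaping the outputs back keeps everything in a single SVD call.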

Assuming we already have fundamental matrices of shape B x Ransac_num x 3 x 3, a naive idea may be: (a) use RFC to filter the fundamental matrices and pick the best one, (b) estimate the focal lengths with focal_svd or bougnoux_torch, (c) convert the fundamental matrix to an essential matrix using the extracted focal lengths, and (d) finally recover the relative pose. The operations involved look easy to batch, e.g., using torch.linalg.svd. Another option could be to convert your method to a batched version, but this seems quite challenging.
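Steps (c) and (d) can be sketched in batched form as follows. This is only an illustration under stated assumptions: F satisfies x2^T F x1 = 0 in pixel coordinates, the cameras have square pixels and no skew, and the function names are hypothetical. The decomposition returns the two candidate rotations and the translation direction up to sign; choosing among the four (R, t) combinations still requires a cheirality test on triangulated points.

```python
import torch


def essential_from_fundamental(F, f1, f2, pp1, pp2):
    """F: (N, 3, 3), f1, f2: (N,), pp1, pp2: (N, 2). Returns E = K2^T F K1."""
    def make_K(f, pp):
        K = torch.zeros(f.shape[0], 3, 3, dtype=F.dtype, device=F.device)
        K[:, 0, 0] = f
        K[:, 1, 1] = f
        K[:, 0, 2] = pp[:, 0]
        K[:, 1, 2] = pp[:, 1]
        K[:, 2, 2] = 1.0
        return K

    K1, K2 = make_K(f1, pp1), make_K(f2, pp2)
    return K2.transpose(-1, -2) @ F @ K1


def decompose_essential(E):
    """E: (N, 3, 3). Returns R_a, R_b of shape (N, 3, 3) and t of shape (N, 3).
    The pose is one of (R_a, t), (R_a, -t), (R_b, t), (R_b, -t)."""
    U, _, Vh = torch.linalg.svd(E)
    # Flip signs so that both factors are proper rotations (det = +1).
    U = U * torch.sign(torch.linalg.det(U))[:, None, None]
    Vh = Vh * torch.sign(torch.linalg.det(Vh))[:, None, None]
    W = torch.tensor(
        [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
        dtype=E.dtype, device=E.device,
    )
    R_a = U @ W @ Vh
    R_b = U @ W.T @ Vh
    t = U[:, :, 2]  # unit vector spanning the left null space of E
    return R_a, R_b, t
```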

Do you think the solution above may achieve a good result, and what can I do to further improve the performance? Furthermore, if we know that for all the batches 'pts1' originates from the same image (that is, I am predicting the relative poses from 'B' images to a single query image), could this potentially simplify the problem?

Best, Jianyuan

kocurvik commented 6 months ago

If I understand correctly, you need to calculate the relative pose in a differentiable manner.

Ideally in step (a) you would also use the remaining points to select the inliers. You can use RFC to save some computation time here.
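One common way to score each candidate F on the remaining points is the Sampson error. The sketch below is a batched version under assumptions: the convention is x2^T F x1 = 0, the function names are hypothetical, and the 1 px threshold is only a placeholder that would need tuning.

```python
import torch


def sampson_error(F, pts1, pts2):
    """F: (N, 3, 3), pts1, pts2: (N, M, 2).
    Returns (N, M) squared Sampson distances for x2^T F x1 = 0."""
    ones = torch.ones_like(pts1[..., :1])
    x1 = torch.cat([pts1, ones], dim=-1)   # (N, M, 3) homogeneous points
    x2 = torch.cat([pts2, ones], dim=-1)
    Fx1 = x1 @ F.transpose(-1, -2)         # row m holds (F x1_m)^T
    Ftx2 = x2 @ F                          # row m holds (F^T x2_m)^T
    num = (x2 * Fx1).sum(dim=-1) ** 2
    den = (Fx1[..., 0] ** 2 + Fx1[..., 1] ** 2
           + Ftx2[..., 0] ** 2 + Ftx2[..., 1] ** 2)
    return num / den.clamp(min=torch.finfo(den.dtype).tiny)


def count_inliers(F, pts1, pts2, thresh_px=1.0):
    """Number of matches per candidate with Sampson error below thresh_px**2."""
    return (sampson_error(F, pts1, pts2) < thresh_px ** 2).sum(dim=-1)
```

To score B x Ransac_num candidates against the same matches of each pair, the points can be expanded along the Ransac_num dimension before flattening to N = B * Ransac_num.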

In (b) you can use the bougnoux_torch function. Or you can check out bougnoux_rybkin, which is even simpler. It uses the same formulas that we use for RFC. In RFC we use only one of the three possible formulas for efficiency. The linked implementation uses all three formulas (see Oleh Rybkin's Bachelor Thesis), and the if statements check for degeneracy, since each of these simpler formulas has additional degenerate configurations which the Bougnoux formula does not have. However, in practice they do not occur very often, so you can just use the first formula. Note that there is a bug in the implementation, but I did not fix it since it isn't used in the code. The correct version should calculate the argmax of the absolute values of (den1, den2, den3).

The simplified implementation would look like this:

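    def focal_sq(F):  # illustrative name; F is a 3x3 fundamental matrix, the result is f**2
        (f_11, f_12, f_13), (f_21, f_22, f_23), (f_31, f_32, f_33) = F
        den = f_11*f_12*f_31*f_33-f_11*f_13*f_31*f_32+f_12**2*f_32*f_33-f_12*f_13*f_32**2+f_21*f_22*f_31*f_33-f_21*f_23*f_31*f_32+f_22**2*f_32*f_33-f_22*f_23*f_32**2
        num = -f_33*(f_12*f_13*f_33-f_13**2*f_32+f_22*f_23*f_33-f_23**2*f_32)
        return num / den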

You can call it with F.T instead of F to get the other focal length.
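A quick way to sanity-check the formula is to build F from a known pose and known focal lengths and see what it returns. The sketch below assumes principal points at the origin, the convention x2^T F x1 = 0 with F = K2^{-T} [t]_x R K1^{-1}, and a generic (non-degenerate) pose; all helper names are hypothetical. Under these assumptions num / den should be the squared focal length of camera 1, and F.T should give camera 2.

```python
import math
import torch


def first_formula(F):
    """Batched num and den of the first simplified formula. F: (N, 3, 3)."""
    f_11, f_12, f_13 = F[:, 0, 0], F[:, 0, 1], F[:, 0, 2]
    f_21, f_22, f_23 = F[:, 1, 0], F[:, 1, 1], F[:, 1, 2]
    f_31, f_32, f_33 = F[:, 2, 0], F[:, 2, 1], F[:, 2, 2]
    den = (f_11 * f_12 * f_31 * f_33 - f_11 * f_13 * f_31 * f_32
           + f_12 ** 2 * f_32 * f_33 - f_12 * f_13 * f_32 ** 2
           + f_21 * f_22 * f_31 * f_33 - f_21 * f_23 * f_31 * f_32
           + f_22 ** 2 * f_32 * f_33 - f_22 * f_23 * f_32 ** 2)
    num = -f_33 * (f_12 * f_13 * f_33 - f_13 ** 2 * f_32
                   + f_22 * f_23 * f_33 - f_23 ** 2 * f_32)
    return num, den


def skew(t):
    """Cross-product matrix [t]_x of a 3-vector."""
    tx, ty, tz = t.tolist()
    return torch.tensor([[0, -tz, ty], [tz, 0, -tx], [-ty, tx, 0]], dtype=t.dtype)


def synthetic_F(R, t, f1, f2):
    """F = K2^{-T} [t]_x R K1^{-1} with K = diag(f, f, 1) (assumed model)."""
    K1_inv = torch.diag(torch.tensor([1 / f1, 1 / f1, 1.0], dtype=R.dtype))
    K2_inv = torch.diag(torch.tensor([1 / f2, 1 / f2, 1.0], dtype=R.dtype))
    return K2_inv.T @ skew(t) @ R @ K1_inv


# Generic pose: rotation about x and y, translation with all components non-zero.
ca, sa, cb, sb = math.cos(0.2), math.sin(0.2), math.cos(0.3), math.sin(0.3)
Rx = torch.tensor([[1, 0, 0], [0, ca, -sa], [0, sa, ca]], dtype=torch.float64)
Ry = torch.tensor([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]], dtype=torch.float64)
t = torch.tensor([1.0, 0.5, 0.1], dtype=torch.float64)

F = synthetic_F(Ry @ Rx, t, f1=1000.0, f2=1200.0)[None]   # add a batch dimension
num, den = first_formula(F)
rfc_ok = num * den > 0                 # sign check, as used for RFC
f1_est = (num / den).sqrt()            # estimate for camera 1
num_t, den_t = first_formula(F.transpose(-1, -2))
f2_est = (num_t / den_t).sqrt()        # estimate for camera 2
```

Double precision is used on purpose here, since the entries of a pixel-space F span several orders of magnitude.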

I am not sure if there would be problems with the gradients, so maybe you can try both of these approaches. Note that if you do RFC anyway, it performs almost the full computation, but at the end you only calculate the sign of (num * den). So in your case you can calculate both num and den, check their signs, discard the Fs that give a negative sign, and once you have selected the right F, just calculate num / den.

Note that the function focal_svd that you link is specifically for the case when one camera is calibrated and the other is not. You should not use that.

jytime commented 6 months ago

Hi Viktor,

Thanks for your prompt reply. I quickly implemented it and gave it a try. bougnoux_rybkin and bougnoux_torch show similar results on my side (testing on IMC Phototourism). However, I found that the focal length prediction is quite unstable. Therefore, instead of estimating the focal length from the F mat with the most inliers, I pick the top 50 F mats, extract their focal lengths, and compute their average. Do you think such an operation makes sense mathematically?

At the same time, are there any reliable metrics that can measure the accuracy of the predicted focal lengths using only the F mat and the 2D matches? If so, we could further filter out some inaccurate predictions.

In case someone else needs such an implementation, I share my code below:

import torch

# all_fmat has shape B x all_fmat_num x 3 x 3
B, all_fmat_num, _, _ = all_fmat.shape

# check RFC
rfc_mask = calculate_RFC_batched(all_fmat.reshape(B * all_fmat_num, 3, 3))
rfc_mask = rfc_mask.reshape(B, all_fmat_num)

# If RFC is False, set the inlier number and the indicator as 0
residual_indicator = residual_indicator * rfc_mask
inlier_num_all = inlier_num_all * rfc_mask

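# pick the 50 best F (by inlier count) from all_fmat_num
sorted_values, sorted_indices = torch.sort(inlier_num_all, dim=1, descending=True)
good_num = 50
good_indices = sorted_indices[:, :good_num]
# batch_index is not defined in this snippet; it is assumed to hold each row's batch id,
# e.g. torch.arange(B, device=all_fmat.device)[:, None]
good_fmat = all_fmat[batch_index[:, 0:1].expand(-1, good_num), good_indices]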

# estimate the focal length for left and right camera
f1 = bougnoux_rybkin_batched(good_fmat.reshape(B * good_num, 3, 3)).sqrt()
f2 = bougnoux_rybkin_batched(good_fmat.transpose(-1, -2).reshape(B * good_num, 3, 3)).sqrt()

f1 = f1.reshape(B, good_num)
f2 = f2.reshape(B, good_num)

# f1 corresponds to a single camera (the shared query image), so
# average f1 over B x good_num -> B to get its mean prediction
# compute_valid_average skips potential NaN or Inf values
# maxsize restricts focal lengths to the range [maxsize / valid_thres, maxsize * valid_thres]
# valid_thres is 4: we assume the focal length is not very far from the longer image side
f1_mean = compute_valid_average(f1, maxsize=maxsize, thres=valid_thres)

# compute the ratio of f2 to f1, used below to scale f1_real into f2_real
f21_ratio = compute_valid_average(f2) / compute_valid_average(f1)

# focal length for camera 1 and 2 (in pixel)
# e.g., for a 1024x1024 image, its f1_real is usually 1057.6292 or something close
f1_real = f1_mean.mean()
f2_real = f1_real * f21_ratio

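def bougnoux_rybkin_batched(F, checkRFC=False):
    # F is a Bx3x3 torch tensor. Returns the squared focal length per batch element
    # (callers take .sqrt()), or a boolean RFC mask if checkRFC is True.
    # Extract elements of F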
    f_11, f_12, f_13 = F[:, 0, 0], F[:, 0, 1], F[:, 0, 2]
    f_21, f_22, f_23 = F[:, 1, 0], F[:, 1, 1], F[:, 1, 2]
    f_31, f_32, f_33 = F[:, 2, 0], F[:, 2, 1], F[:, 2, 2]

    # Compute the denominators for each element in the batch
    den_1 = (
        f_11 * f_12 * f_31 * f_33
        - f_11 * f_13 * f_31 * f_32
        + f_12.pow(2) * f_32 * f_33
        - f_12 * f_13 * f_32.pow(2)
        + f_21 * f_22 * f_31 * f_33
        - f_21 * f_23 * f_31 * f_32
        + f_22.pow(2) * f_32 * f_33
        - f_22 * f_23 * f_32.pow(2)
    )
    den_2 = (
        f_11.pow(2) * f_31 * f_33
        + f_11 * f_12 * f_32 * f_33
        - f_11 * f_13 * f_31.pow(2)
        - f_12 * f_13 * f_31 * f_32
        + f_21.pow(2) * f_31 * f_33
        + f_21 * f_22 * f_32 * f_33
        - f_21 * f_23 * f_31.pow(2)
        - f_22 * f_23 * f_31 * f_32
    )
    den_3 = (
        f_11.pow(2) * f_31 * f_32
        - f_11 * f_12 * f_31.pow(2)
        + f_11 * f_12 * f_32.pow(2)
        - f_12.pow(2) * f_31 * f_32
        + f_21.pow(2) * f_31 * f_32
        - f_21 * f_22 * f_31.pow(2)
        + f_21 * f_22 * f_32.pow(2)
        - f_22.pow(2) * f_31 * f_32
    )

    # Compute the numerators based on the index and return the result for each batch element
    num_1 = -f_33 * (f_12 * f_13 * f_33 - f_13.pow(2) * f_32 + f_22 * f_23 * f_33 - f_23.pow(2) * f_32)
    num_2 = -f_33 * (f_11 * f_13 * f_33 - f_13.pow(2) * f_31 + f_21 * f_23 * f_33 - f_23.pow(2) * f_31)
    num_3 = -f_33 * (f_11 * f_13 * f_32 - f_12 * f_13 * f_31 + f_21 * f_23 * f_32 - f_22 * f_23 * f_31)

    if checkRFC:
        condition1 = num_1 * den_1 < 0
        condition2 = num_2 * den_2 < 0
        condition3 = num_3 * den_3 < 0
        return ~(condition1 | condition2 | condition3)

    # Stack the denominators and find the index of the maximum for each batch element
    dens = torch.stack([den_1, den_2, den_3], dim=-1)
    # Make dens absolute
    i_max = torch.argmax(dens.abs(), dim=-1)

    nums = torch.stack([num_1, num_2, num_3], dim=-1)

    # Select the numerator and denominator based on the max index for each batch element
    selected_nums = torch.gather(nums, 1, i_max.unsqueeze(-1)).squeeze(-1)
    selected_dens = torch.gather(dens, 1, i_max.unsqueeze(-1)).squeeze(-1)

    return selected_nums / selected_dens

def compute_valid_average(tensor, maxsize=None, thres=4):
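    """
    Computes the average of each row in the tensor, excluding NaN and Inf values
    and, if maxsize is given, values outside [maxsize / thres, maxsize * thres].

    Parameters:
    - tensor (Tensor): A 2D tensor of shape BxN.
    - maxsize (float, optional): Reference size (e.g. the longer image side) for the range check.
    - thres (float): Factor defining the accepted range around maxsize.

    Returns:
    - Tensor: A 1D tensor of length B, where each element is the average of the valid entries
              in the corresponding row (0 if the row has no valid entries).
    """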

    # Step 1: Create a mask for valid numbers
    valid_mask = torch.isfinite(tensor)

    if maxsize is not None:
        thres_mask = torch.logical_and(tensor <= (maxsize * thres), tensor >= (maxsize / thres))
        valid_mask = torch.logical_and(valid_mask, thres_mask)

    # Step 2: Compute the sum of valid elements for each row
    valid_sum = torch.where(valid_mask, tensor, torch.zeros_like(tensor)).sum(dim=1)

    # Step 3: Count the number of valid elements in each row
    valid_count = valid_mask.sum(dim=1)

    # Step 4: Compute the average, avoiding division by zero
    average = valid_sum / valid_count.clamp(min=1)

    return average

def calculate_RFC_batched(F):
    # Ensure F is a PyTorch tensor of shape [batch_size, 3, 3]

    den1 = F[:, 0, 0] * F[:, 0, 1] * F[:, 2, 0] * F[:, 2, 2] - F[:, 0, 0] * F[:, 0, 2] * F[:, 2, 0] * F[:, 2, 1] + \
           F[:, 0, 1] ** 2 * F[:, 2, 1] * F[:, 2, 2] - F[:, 0, 1] * F[:, 0, 2] * F[:, 2, 1] ** 2 + \
           F[:, 1, 0] * F[:, 1, 1] * F[:, 2, 0] * F[:, 2, 2] - F[:, 1, 0] * F[:, 1, 2] * F[:, 2, 0] * F[:, 2, 1] + \
           F[:, 1, 1] ** 2 * F[:, 2, 1] * F[:, 2, 2] - F[:, 1, 1] * F[:, 1, 2] * F[:, 2, 1] ** 2

    num1 = -F[:, 2, 2] * (F[:, 0, 1] * F[:, 0, 2] * F[:, 2, 2] - F[:, 0, 2] ** 2 * F[:, 2, 1] + \
                          F[:, 1, 1] * F[:, 1, 2] * F[:, 2, 2] - F[:, 1, 2] ** 2 * F[:, 2, 1])

    condition1 = num1 * den1 < 0

    den2 = F[:, 0, 0] * F[:, 1, 0] * F[:, 0, 2] * F[:, 2, 2] - F[:, 0, 0] * F[:, 2, 0] * F[:, 0, 2] * F[:, 1, 2] + \
           F[:, 1, 0] ** 2 * F[:, 1, 2] * F[:, 2, 2] - F[:, 1, 0] * F[:, 2, 0] * F[:, 1, 2] ** 2 + \
           F[:, 0, 1] * F[:, 1, 1] * F[:, 0, 2] * F[:, 2, 2] - F[:, 0, 1] * F[:, 2, 1] * F[:, 0, 2] * F[:, 1, 2] + \
           F[:, 1, 1] ** 2 * F[:, 1, 2] * F[:, 2, 2] - F[:, 1, 1] * F[:, 2, 1] * F[:, 1, 2] ** 2

    num2 = -F[:, 2, 2] * (F[:, 1, 0] * F[:, 2, 0] * F[:, 2, 2] - F[:, 2, 0] ** 2 * F[:, 1, 2] + \
                          F[:, 1, 1] * F[:, 2, 1] * F[:, 2, 2] - F[:, 2, 1] ** 2 * F[:, 1, 2])

    condition2 = num2 * den2 < 0

    # Return a tensor of booleans indicating False where either condition1 or condition2 is True
    return ~(condition1 | condition2)

Parskatt commented 3 months ago

I'm having similar issues with instability. I guess I shouldn't be surprised, but just as an example: [image]

Parskatt commented 3 months ago

I feel stupid, can it really be this sensitive? [image]

Parskatt commented 3 months ago

Is there something special with using precisely 1?

EDIT: No, there does not seem to be anything special, it just seems I was very lucky there (specifically, 50 iters for 1 gives me correct intrinsics, but e.g. 51 iters does not).

Parskatt commented 3 months ago

Ok, so basically if iterations don't converge I shouldn't trust it I guess.

kocurvik commented 3 months ago

> Ok, so basically if iterations don't converge I shouldn't trust it I guess.

Yes, whenever it does not converge in time it is usually a sign of a bad estimate. We did some additional experiments for the camera-ready version (not yet on arXiv), and near the degenerate configurations the method is more likely to not converge within 50 iterations. In some configurations it is very difficult to get a good estimate of the focal lengths from the fundamental matrix. Our method suffers from this as well, but to a lesser extent than previous approaches. You can also observe this in the paper: in Fig. 4, for Phototourism, about 30% of samples have a relative error in f greater than 30% with our method.

It is actually surprising that you got a good estimate by chance. Maybe there is some oscillation around the correct intrinsics even in these failure cases.

kocurvik commented 3 months ago

> Hi Viktor,
>
> Thanks for your prompt reply. I quickly implemented it and gave it a try. bougnoux_rybkin and bougnoux_torch show similar results on my side (testing on IMC Phototourism). However, I found that the focal length prediction is quite unstable. Therefore, instead of estimating the focal length from the F mat with the most inliers, I pick the top 50 F mats, extract their focal lengths, and compute their average. Do you think such an operation makes sense mathematically?
>
> At the same time, are there any reliable metrics that can measure the accuracy of the predicted focal lengths using only the F mat and the 2D matches? If so, we could further filter out some inaccurate predictions.
>
> In case someone else needs such an implementation, I share my code below:

Sorry, I did not notice your comment earlier. The instability is caused by the singularity that occurs when the principal axes are coplanar (they intersect), and by configurations which are close to this. This happens a lot in practice. I don't know if averaging the results is the way to go, since you would be averaging values obtained for slightly different Fs representing the same configuration. Did you manage to get good results using it?
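As a rough illustration of the coplanarity condition: with the principal points at the origin, the principal axes intersect exactly when the (3, 3) entry of F vanishes, so its relative size can serve as a warning flag. The sketch below is a heuristic under assumptions, not something from the repository: the function name is hypothetical, f_guess is a nominal focal length used only to balance the units of F, and any cutoff on the score would have to be chosen empirically.

```python
import torch


def principal_axes_coplanarity(F, f_guess):
    """F: (N, 3, 3) with principal points at the origin, f_guess: float.
    Returns |E_33| / ||E||_F for E = K^T F K with K = diag(f_guess, f_guess, 1).
    Scores near 0 suggest (nearly) coplanar principal axes."""
    k = torch.tensor([f_guess, f_guess, 1.0], dtype=F.dtype, device=F.device)
    E = F * k[:, None] * k[None, :]    # equals K^T F K for a diagonal K
    return E[:, 2, 2].abs() / torch.linalg.matrix_norm(E)
```

Candidates with a very small score could be down-weighted or skipped before their focal lengths are averaged.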

jytime commented 3 months ago

Hi, thanks for the reply. The short answer is that averaging gives more stable results in practice, but the improvement is not significant.