facebookresearch / theseus

A library for differentiable nonlinear optimization
MIT License

Batching of error function #565

Closed: jytime closed this issue 1 year ago

jytime commented 1 year ago

Hi Theseus team,

Thanks for your great work! Having second-order optimisers available is really helpful. Our team is trying to use Theseus to run bundle adjustment (BA) in our structure-from-motion project; the use case is similar to the Theseus BA example.

We've written an implementation that follows your BA example, adapted to our coordinate conventions and data. However, optimization is very slow. For a scene with 100 cameras, 3500 3D points, and 43000 observations, it takes about 30 seconds per iteration with LevenbergMarquardt (650 seconds for 20 iterations). We see similar speeds when using the reprojection error cost, which computes the Jacobian matrix in closed form. In this case (vectorized_err_fn = False), our mean L2 error drops from 0.0620 to around 0.0153 after 20 iterations.

We initially thought the slowness was due to our error function not being batched, as in issue #290. With the error function batched (vectorized_err_fn = True in the code below), optimization is indeed much faster, at roughly 1.5 seconds per iteration (about 31 seconds for 20 iterations), but the result is not satisfactory: the mean L2 error (print(torch.norm(err, dim=-1).mean())) only drops to 0.0481 after 20 iterations. I tried adjusting hyperparameters such as step_size, but this did not clearly help.
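One structural difference between the two versions may be worth checking: proj_point_err_batch ends with err.abs().mean(dim=1), so the batched cost exposes a single 2-D residual for the whole scene, while the per-observation version keeps two residuals per observation. Gauss-Newton and Levenberg-Marquardt steps are built from the residual Jacobian J through J^T J, and rank(J^T J) = rank(J), which is at most 2 when J has only two rows. The NumPy sketch below illustrates this on toy linear residuals; the random data, sizes, and helper name are assumptions, and this is not the BA model itself.

```python
import numpy as np


def normal_matrix_rank(jac: np.ndarray) -> int:
    """Rank of the Gauss-Newton normal matrix J^T J (equal to rank(J))."""
    return int(np.linalg.matrix_rank(jac))


rng = np.random.default_rng(0)
num_obs, num_vars = 200, 30  # toy sizes, not the real BA problem

# Toy linear residuals r_i(x) = A_i x - b_i, one 2-D residual per observation.
A = rng.normal(size=(num_obs, 2, num_vars))
b = rng.normal(size=(num_obs, 2))
x = rng.normal(size=num_vars)

# Per-observation residuals flattened to a (2 * num_obs,) vector.
jac_per_obs = A.reshape(2 * num_obs, num_vars)

# Mean-of-absolute-value residual, as in err.abs().mean(dim=1):
# r_agg[k] = mean_i |r_i[k]|, so d r_agg[k] / dx = mean_i sign(r_i[k]) * A_i[k].
r = A @ x - b                                       # (num_obs, 2)
jac_mean_abs = (np.sign(r)[..., None] * A).mean(axis=0)  # (2, num_vars)

print(normal_matrix_rank(jac_per_obs))   # 30
print(normal_matrix_rank(jac_mean_abs))  # 2
```

A rank-2 J^T J also has a singular leading minor of order 3, which would be consistent with the Gauss-Newton Cholesky error quoted in the script's comments, although those comments do not say which mode produced it. Keeping per-observation residuals restores the information, at the price of a much larger residual dimension (2 × 43855 in this scene).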

To reproduce the result, the data sample "badataset_208.pkl" can be found here.

Is our current implementation correct? Apart from batching the error function, are there other strategies we could use to speed up convergence? Any guidance would be greatly appreciated.

Best regards, Jianyuan

import logging
import pickle
import random
import time
from typing import Dict, Type

import numpy as np
import torch

import theseus as th
import theseus.utils.examples as theg

log = logging.getLogger(__name__)

def get_batch(
    ba: theg.BundleAdjustmentDataset,
    orig_poses: Dict[str, torch.Tensor],
    orig_points: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    retv = {}
    for cam in ba.cameras:
        retv[cam[0].name] = orig_poses[cam[0].name].clone()

    for pt in ba.points:
        retv[pt.name] = orig_points[pt.name].clone()
    return retv

def base_proj_points(pose, point3d, fl):
    # P * R + t
    transfer = point3d.tensor.unsqueeze(1) @ pose.tensor[..., :3]  # (N, 1, 3)
    transfer = transfer.squeeze(1)  # (N, 1, 3) -> (N, 3)
    transfer = transfer + pose.tensor[...,3]

    # to 2d point
    tmp = transfer[...,:2] / transfer[...,2:] 
    point_projection = tmp * fl.tensor
    return point_projection

def proj_point_err(optim_vars, aux_vars):
    pose, point3d = optim_vars
    point2d, fl = aux_vars
    point_projection = base_proj_points(pose, point3d, fl)
    err = point_projection - point2d.tensor
    err = torch.clamp(err, max=1, min=-1)

    if err.shape[0]>1:
        print(torch.norm(err, dim=1).mean())
    return err

def proj_point_wo_pose_opt_err(optim_vars, aux_vars):
    point3d = optim_vars[0]
    pose, point2d, fl = aux_vars
    point_projection = base_proj_points(pose, point3d, fl)
    err = point_projection - point2d.tensor
    err = torch.clamp(err, max=1, min=-1)

    if err.shape[0]>1:
        print(torch.norm(err, dim=1).mean())
    return err

def proj_point_err_batch(optim_vars, aux_vars):
    other_pose, points3d = optim_vars
    first_pose, fl, image_id, point_id, image_point = aux_vars

    point_id = point_id.tensor.long()   # [1, 43855]
    image_id = image_id.tensor.long()   # [1, 43855]

    # [100, 3, 4]
    pose = torch.cat([first_pose.tensor, other_pose.tensor.reshape(-1, 3, 4)], dim=0)
    # [3543, 3]
    p3d = points3d.tensor.reshape(-1, 3)

    picked_p3d = p3d[point_id[0]]       # [43855, 3]
    picked_pose = pose[image_id[0]]     # [43855, 3, 4]
    picked_fl = fl.tensor[0][image_id[0]]   # [43855]

    # P * R + t
    transfer = picked_p3d.unsqueeze(1) @ picked_pose[..., :3]  # (N, 1, 3)
    transfer = transfer.squeeze(1)  # (N, 1, 3) -> (N, 3)
    transfer = transfer + picked_pose[...,3]

    # to 2d point
    tmp = transfer[...,:2] / transfer[...,2:] 
    point_projection = tmp * picked_fl.unsqueeze(-1)
    # 1xNx2
    err = point_projection.unsqueeze(0) - image_point.tensor
    err = torch.clamp(err, max=1, min=-1)
    # 1xNx2
    print(torch.norm(err, dim=-1).mean())
    # mean |err| over all N observations, so the whole scene yields one residual
    err = err.abs().mean(dim=1)
    # 1x2
    return err

def run_ba(vectorized_err_fn=True):
    DATA_PATH = 'badataset_208.pkl'

    with open(DATA_PATH, 'rb') as f:  
        ba = pickle.load(f)    

    device = "cuda"
    # Set up objective
    objective = th.Objective(dtype=torch.float64)
    weight = th.ScaleCostWeight(torch.tensor(1.0, device=device).to(dtype=ba.cameras[0][0].dtype))

    if vectorized_err_fn:
        fl_list = []
        pose_list = []
        image_id_list = []
        point2d_list = []
        point3d_list = []
        point3d_id_list =[]

        for i in range(len(ba.cameras)):
            # assume the calibration (distortion) terms are 0
            pose, fl, _, _ = ba.cameras[i]
            pose_list.append(pose.tensor)
            fl_list.append(fl.tensor)

        for j in range(len(ba.points)):
            point3d_list.append(ba.points[j].tensor)

        for obs in ba.observations:
            image_id_list.append(obs.camera_index)
            point2d_list.append(obs.image_feature_point.tensor)
            point3d_id_list.append(obs.point_index)

        # (1 x N_img)
        fl =  th.Vector(tensor=torch.cat(fl_list).reshape(1,-1).to(device).double(), name=f"fl")
        poses = torch.cat(pose_list).to(device).double()

        # fix the first camera to stabilize the BA system
        first_pose = th.Variable(poses[0:1], name="first_pose")
        other_pose = th.Vector(tensor=poses[1:].reshape(1,-1), name=f"other_pose")

        # 1 x N_observation (camera index of each observation)
        image_id_list = np.array(image_id_list)
        image_id = th.Variable(torch.from_numpy(image_id_list)[None].to(device).double(), name = "image_id")
        # 1 x N_observation x 2
        point2d = th.Variable(torch.cat(point2d_list)[None].to(device).double(), name = "point2d")

        # 1 x N_point3d x 3 -> 1 x (N_point3d * 3)
        point3d = th.Vector(tensor = torch.cat(point3d_list).reshape(1,-1).to(device).double(), 
                                    name = f"point3d")

        point3d_id_list = np.array(point3d_id_list)
        point3d_id = th.Variable(torch.from_numpy(point3d_id_list)[None].to(device).double(), 
                                 name = "point3d_id")            

        theseus_inputs = {"fl": fl.tensor,
                           "point3d_id": point3d_id.tensor,"image_id": image_id.tensor,
                           "point2d": point2d.tensor,"first_pose": first_pose.tensor,
                           "other_pose": other_pose.tensor,"point3d": point3d.tensor,}

        optim_vars = [other_pose, point3d]
        aux_vars = [first_pose, fl, image_id, point3d_id, point2d]

        cost_function = th.AutoDiffCostFunction(
                    optim_vars, proj_point_err_batch, 2, 
                    aux_vars=aux_vars, cost_weight=weight,)

        objective.add(cost_function)
    else:
        for obs in ba.observations:
            # assume the calibration (distortion) terms are 0
            pose, fl, _, _ = ba.cameras[obs.camera_index]
            point3d = ba.points[obs.point_index]
            point2d = obs.image_feature_point
            if obs.camera_index == 0:
                # fix the first camera to stabilize the BA system
                optim_vars = [point3d]
                aux_vars = [pose, point2d, fl]
                cost_function = th.AutoDiffCostFunction(
                        optim_vars, proj_point_wo_pose_opt_err, 2, 
                        aux_vars=aux_vars, cost_weight=weight,)
            else:
                optim_vars = [pose, point3d]
                aux_vars = [point2d, fl]
                cost_function = th.AutoDiffCostFunction(
                        optim_vars, proj_point_err, 2, 
                        aux_vars=aux_vars, cost_weight=weight,)

            objective.add(cost_function)

    dtype = objective.dtype

    loss_val = objective.error()

    # Set up optimizer
    optimizer_cls: Type[th.NonlinearLeastSquares] = getattr(
        th, "LevenbergMarquardt"       
        # GaussNewton is not applicable here because "RuntimeError: There was an error while running the linear optimizer
        # Original error message: linalg.cholesky: (Batch element 0): 
        # The factorization could not be completed because the input is not positive-definite (the leading minor of order 3 is not positive-definite)"
    )

    optimizer = optimizer_cls(
        objective,
        max_iterations=20,
        step_size=0.1,)

    # Set up Theseus layer
    theseus_optim = th.TheseusLayer(optimizer)

    if not vectorized_err_fn:
        orig_poses = {cam[0].name: cam[0].tensor.clone() for cam in ba.cameras}
        orig_points = {point.name: point.tensor.clone() for point in ba.points}
        theseus_inputs = get_batch(ba, orig_poses, orig_points)

    start_time = time.time_ns()

    theseus_outputs, info = theseus_optim.forward(
        input_tensors=theseus_inputs,
        optimizer_kwargs={
            "verbose": True,
            "track_err_history": True,},)

    end_time = time.time_ns()
    print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds")

    # if vectorized_err_fn = False
    # Epoch took  656.086 seconds 
    # tensor(0.0153, device='cuda:0', dtype=torch.float64)                                                                                                                          

    # if vectorized_err_fn = True
    # Epoch took 30.870  seconds 
    # tensor(0.0481, device='cuda:0', dtype=torch.float64)

    return theseus_outputs, info

def main():
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    run_ba()

if __name__ == "__main__":
    main()

luisenp commented 1 year ago

Hi @jytime, thanks for your interest and reporting this.

I tried to run the script but got an import error when loading the pickle file in line 108:

Exception has occurred: ModuleNotFoundError
No module named 'util'

I'll still try to go over the code to see if I spot something, and let you know if I have any comments or questions.

jytime commented 1 year ago

Hi @luisenp ,

Thanks for your reply, and sorry: I realized I had slightly changed the structure of the pkl file.

I have updated it, and the new file can be found here. It can be opened with the standard Theseus library. Please let me know if it still does not work :)

luisenp commented 1 year ago

Hi @jytime, thanks for the update, I'm able to run now.

I ran some profiling, and it seems that a significant reason for the slowdown is that you are using the default torch dense solver, which is not very fast. You can change this by passing linear_solver_cls= when you construct the optimizer. Options are: th.CholmodSparseSolver (CPU), th.LUCudaSparseSolver (GPU), or th.BaspachoSparseSolver (CPU/GPU).
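To give a rough sense of why a dense solver struggles at this scale, the sketch below counts the entries of the Gauss-Newton normal matrix for the per-observation formulation. It uses the counts from the script's comments (100 cameras with the first one fixed, 3543 points, 43855 observations) and assumes 6 degrees of freedom per SE(3) camera pose and 3 per point; the helper name is hypothetical, and the sparse count is an upper bound that treats every observation as a distinct (free camera, point) pair.

```python
def normal_matrix_sizes(num_free_cams, num_points, num_obs, pose_dof=6, point_dof=3):
    """Size of the Gauss-Newton normal matrix J^T J for a BA problem.

    Returns (num_vars, dense_entries, sparse_entries_upper_bound).
    """
    num_vars = num_free_cams * pose_dof + num_points * point_dof
    dense_entries = num_vars * num_vars
    # Diagonal blocks: one pose-pose block per free camera, one point-point block per point.
    diag_entries = num_free_cams * pose_dof**2 + num_points * point_dof**2
    # Each observation couples one pose with one point: two symmetric off-diagonal blocks.
    # Counting every observation gives an upper bound (repeated pairs or fixed cameras add nothing).
    offdiag_entries = 2 * num_obs * pose_dof * point_dof
    return num_vars, dense_entries, diag_entries + offdiag_entries


n, dense, sparse = normal_matrix_sizes(num_free_cams=99, num_points=3543, num_obs=43855)
print(n)                            # 11223
print(dense)                        # 125955729
print(round(dense * 8 / 2**30, 2))  # dense float64 storage in GiB: 0.94
print(sparse)                       # 1614231 (about 1.3% of the dense count)
```

Factoring the dense matrix with Cholesky costs roughly n^3/3 ≈ 4.7 × 10^11 floating-point operations per iteration, whereas a sparse solver only needs to work on the nonzero block structure (plus fill-in).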

Also, it turns out we had some unexpected behavior (see #569): the linearization was actually running on CPU rather than GPU, even though you were passing GPU tensors. The linear system is quite large, so I wasn't able to run it on the GPU of the machine I was testing on. Instead, I used BaspachoSparseSolver on CPU, which reduced optimization time to about 7.7 seconds.

Now most of the time is spent computing errors and Jacobians. We automatically group and vectorize similar cost functions to speed up computation (I checked, and only 2 groups are created in your case). However, there are still many individual costs, and we use some for loops to populate the data for the vectorized costs. So some manual batching, as you are doing, would help.

I'm not sure what causes the discrepancy. Have you double-checked that both versions produce the same errors and Jacobians for the same input?
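A minimal sketch of such a check in NumPy with random toy data: it compares a per-observation loop against an index-gathering batched version of the same projection. Assumptions: poses follow the script's 3x4 [R | t] convention applied as X @ R + t, there is no clamping, and the batched version deliberately keeps one residual per observation instead of averaging as proj_point_err_batch does. The function names are hypothetical.

```python
import numpy as np


def project(pose: np.ndarray, point: np.ndarray, fl: float) -> np.ndarray:
    """Per-observation projection, following base_proj_points: X @ R + t, divide by z, scale by fl."""
    q = point @ pose[:, :3] + pose[:, 3]
    return fl * q[:2] / q[2]


def project_batched(poses, points, fls, cam_idx, pt_idx):
    """Batched projection by gathering with index arrays, with no reduction over observations."""
    R = poses[cam_idx, :, :3]  # (N, 3, 3)
    t = poses[cam_idx, :, 3]   # (N, 3)
    q = np.einsum("ni,nij->nj", points[pt_idx], R) + t  # row-vector X @ R per observation
    return fls[cam_idx][:, None] * q[:, :2] / q[:, 2:]


rng = np.random.default_rng(0)
num_cams, num_pts, num_obs = 4, 10, 25  # toy sizes

R = np.eye(3) + 0.05 * rng.normal(size=(num_cams, 3, 3))
t = np.array([0.0, 0.0, 5.0]) + 0.1 * rng.normal(size=(num_cams, 3))  # points stay in front
poses = np.concatenate([R, t[:, :, None]], axis=2)  # (num_cams, 3, 4)
points = rng.normal(size=(num_pts, 3))
fls = rng.uniform(0.5, 1.5, size=num_cams)
cam_idx = rng.integers(0, num_cams, size=num_obs)
pt_idx = rng.integers(0, num_pts, size=num_obs)
obs = rng.normal(scale=0.1, size=(num_obs, 2))

err_loop = np.stack([project(poses[c], points[p], fls[c]) for c, p in zip(cam_idx, pt_idx)]) - obs
err_batch = project_batched(poses, points, fls, cam_idx, pt_idx) - obs

print(np.allclose(err_loop, err_batch))  # True
print(err_batch.shape)                   # (25, 2): one 2-D residual per observation
```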

mhmukadam commented 1 year ago

"reduced optimization time to about 7.7 seconds"

@luisenp, woohoo, nice!

jytime commented 1 year ago

Hi @luisenp ,

Thanks for figuring this out, it is really helpful! I will write a closed-form Jacobian function for the code, then check whether we can run a GPU linear solver on it and what the best achievable speed is.
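A closed-form Jacobian for the projection used in the script could look like the following minimal NumPy sketch, checked against central finite differences. Assumptions: the pose is treated as a raw 3x4 matrix [R | t] (as in the vectorized version's other_pose), not as an SE(3) element with a 6-D tangent space, and the clamp is ignored; the helper names are hypothetical.

```python
import numpy as np


def project(pose: np.ndarray, point: np.ndarray, fl: float) -> np.ndarray:
    """u = fl * (q_x, q_y) / q_z with q = point @ R + t and pose = [R | t] (3x4)."""
    q = point @ pose[:, :3] + pose[:, 3]
    return fl * q[:2] / q[2]


def project_jacobians(pose, point, fl):
    """Closed-form Jacobians w.r.t. the point (2x3) and the 12 pose entries (2x12, row-major)."""
    R, t = pose[:, :3], pose[:, 3]
    q = point @ R + t
    # Derivative of the perspective division, scaled by the focal length.
    du_dq = fl * np.array([[1.0 / q[2], 0.0, -q[0] / q[2] ** 2],
                           [0.0, 1.0 / q[2], -q[1] / q[2] ** 2]])
    # q_j = sum_i point_i * R[i, j] + t_j, so dq/dpoint = R^T.
    du_dpoint = du_dq @ R.T
    du_dpose = np.empty((2, 3, 4))
    # dq_j / dR[i, j] = point_i  ->  du_k / dR[i, j] = du_dq[k, j] * point[i]
    du_dpose[:, :, :3] = du_dq[:, None, :] * point[None, :, None]
    # dq_i / dt_i = 1  ->  du_k / dpose[i, 3] = du_dq[k, i]
    du_dpose[:, :, 3] = du_dq
    return du_dpoint, du_dpose.reshape(2, 12)


def numerical_jacobian(f, x, eps=1e-6):
    """Central finite differences; columns follow the row-major order of x."""
    cols = []
    for idx in range(x.size):
        dx = np.zeros_like(x)
        dx.flat[idx] = eps
        cols.append((f(x + dx) - f(x - dx)) / (2 * eps))
    return np.stack(cols, axis=-1)


rng = np.random.default_rng(0)
pose = np.concatenate([np.eye(3) + 0.05 * rng.normal(size=(3, 3)),
                       np.array([[0.1], [-0.2], [5.0]])], axis=1)
point = rng.normal(size=3)
fl = 1.2

J_point, J_pose = project_jacobians(pose, point, fl)
num_point = numerical_jacobian(lambda p: project(pose, p, fl), point)
num_pose = numerical_jacobian(lambda P: project(P, point, fl), pose)
print(np.allclose(J_point, num_point, atol=1e-6))  # True
print(np.allclose(J_pose, num_pose, atol=1e-6))    # True
```

If the residual is clamped as in the script, entries that hit the ±1 bound have zero derivative, so the matching Jacobian rows would need to be zeroed to stay consistent with the error function.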

To benefit others who may encounter a similar problem, I will keep this issue open and report our progress as we advance.

jytime commented 1 year ago

Hi @luisenp ,

I am running experiments with a closed-form Jacobian implemented. There is something worth mentioning, probably related to the CUDA/CPU problem in #569.

When I ran identical code, I found that the GPU memory usage of LUCudaSparseSolver was more than 20 times higher than that of BaspachoSparseSolver. While I expected BaspachoSparseSolver to save a significant amount of GPU memory, the magnitude of the difference seems odd.

For your convenience, here is a simple sample to reproduce the observation. When we use BaspachoSparseSolver (the default in the code) and monitor GPU usage with nvidia-smi, the peak GPU memory usage only reaches 1068 MiB. However, when we switch to LUCudaSparseSolver (activated by uncommenting the line cfg.ba.linear_solver_cls = "LUCudaSparseSolver"), the peak jumps to ~25000 MiB. Although BaspachoSparseSolver is still around 2x faster than LUCudaSparseSolver in this case (wow, excellent), I suspect that BaspachoSparseSolver may actually be running on CPU rather than GPU.

For additional context: I built BaSpaCho from source, with CUDA 11.7 and PyTorch 2.0.1. I double-checked that BaSpaCho compiled successfully with CUDA support. All the variables (including the objective) have been moved to CUDA. I also checked _linearize_jacobian_impl, where all the tensors appear to be on CUDA. Screenshots of the cProfile timing analysis can also be found here.

Thanks for your attention, and I hope this is not because I installed something incorrectly...

maurimo commented 1 year ago

Hi @jytime and @luisenp! BaSpaCho's author here (and also the author of Theseus's LUCuda wrapper!)

About LUCuda, I can't say exactly what is going on, as it's closed source. For a start, it requires two separate factors (L and U), which are stored in an opaque context where I think some data is further duplicated. On the other hand, BaSpaCho is stripped to the bone in terms of memory usage: you just allocate a (batched) factor, and the Cholesky decomposition happens in place.
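To make the in-place idea concrete, here is a minimal dense sketch of an in-place Cholesky factorization in NumPy: the factor L overwrites the lower triangle of the input buffer, so no second matrix is allocated, in contrast to the separate L and U factors described above. This only illustrates the in-place principle on a small dense matrix; BaSpaCho itself factors sparse, batched matrices, so its actual algorithm differs, and the function name is hypothetical.

```python
import numpy as np


def cholesky_in_place(a: np.ndarray) -> np.ndarray:
    """Overwrite the lower triangle of the SPD matrix `a` with L, where A = L L^T.

    No second factor buffer is allocated; the strict upper triangle is left untouched.
    """
    n = a.shape[0]
    for j in range(n):
        # L[j, j] = sqrt(A[j, j] - sum_k L[j, k]^2), using the already-computed row of L.
        a[j, j] = np.sqrt(a[j, j] - a[j, :j] @ a[j, :j])
        # L[i, j] = (A[i, j] - sum_k L[i, k] L[j, k]) / L[j, j] for rows below the diagonal.
        a[j + 1:, j] = (a[j + 1:, j] - a[j + 1:, :j] @ a[j, :j]) / a[j, j]
    return a


rng = np.random.default_rng(0)
m = rng.normal(size=(6, 6))
A = m @ m.T + 6 * np.eye(6)  # symmetric positive definite
buf = A.copy()               # the single buffer that gets factored in place
cholesky_in_place(buf)
L = np.tril(buf)
print(np.allclose(L @ L.T, A))                # True
print(np.allclose(L, np.linalg.cholesky(A)))  # True
```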

About CPU/GPU in BaSpaCho: BaSpaCho is meant to be a rather low-level library, and it entirely lacks any logic for moving data between CPU and GPU. If the factor is a tensor on "cuda", you had better be using a sparse solver configured to work with GPU data; otherwise BaSpaCho will blindly try to access the data from the CPU and you will get a segmentation fault. In the Theseus wrapper, this is established once and for all when the C++-wrapped SymbolicDecomposition object (baspacho_solver.{cpp,h}) is created, which in turn happens inside the BaspachoSparseSolver Python class (baspacho_sparse_solver.py), according to the dev= argument. From the moment BaspachoSparseSolver is created, the device is fixed, and all data fed to those classes is expected to be on that device.

Why is this choice made so early, and why can't it be changed later? Because the best strategy can be quite different on GPU and CPU. A symbolic decomposition meant for the GPU could in theory be used on the CPU (this is not currently supported), but the "best thing to do" may differ: e.g., the CPU would in general favour a sparser layout (with more branching), while the GPU might prefer a denser factor layout with less branching, even at the cost of some extra zero-fill.

luisenp commented 1 year ago

Thanks for the clarification @maurimo.

The funny thing is that, as a result of this discussion, I noticed a bug in our optimizer that prevents BaSpaCho from being used on CPU whenever torch.cuda.is_available() is true, so I think you wouldn't have been able to run it on CPU even if you had tried, @jytime.

I'll put up a fix for this soon, so this ended up helping us catch a bug anyway!

jytime commented 1 year ago

Hi @maurimo and @luisenp ,

Thanks for your responses and the great BaSpaCho solver. Really helpful!

I am quite glad that my hacky usage helped catch a bug in a funny way (we are hoping to use Theseus's second-order optimisation for large-scale training, so we are a bit concerned about speed, sorry XD). I will keep trying to accelerate the optimisation in our code and "hammer" the Theseus system, and I will share my findings along the way in the hope that they benefit the community.