TRI-ML / packnet-sfm

TRI-ML Monocular Depth Estimation Repository
https://tri-ml.github.io/packnet-sfm/
MIT License

Reprojected loss function #118

Open Wenchao-Du opened 3 years ago

Wenchao-Du commented 3 years ago

@VitorGuizilini where can I find the code and model for the paper "Robust Semi-Supervised Monocular Depth Estimation with Reprojected Distances" (CoRL 2019 spotlight)? Thank you.

VitorGuizilini-TRI commented 3 years ago

Hi, thank you for your interest. We have not yet added support for this loss function, but I'm planning to do that soon; I'll keep you informed.

pjckoch commented 3 years ago

Hi, are you still planning to release the code for the loss function? Thanks

pjckoch commented 3 years ago

As far as I understand it, part of it should be similar to the view_synthesis() function: https://github.com/TRI-ML/packnet-sfm/blob/c03e4bf929f202ff67819340135c53778d36047f/packnet_sfm/geometry/camera_utils.py#L27-L59

First, to get the world coordinates, call cam.reconstruct() with the lidar depth, then call the same function with the predicted depth. After that, use the predicted pose to project both point clouds (from lidar and from the prediction) onto the reference camera with ref_cam.project(). Then we can compute the Euclidean distance between the two results as our loss, right?
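
If I write it out, the loss I have in mind is roughly the following (my notation, not necessarily the paper's):

$$
\mathcal{L} = \frac{1}{|V|} \sum_{p \in V} \Big\| \pi_{\text{ref}}\big(T\,\hat{D}(p)\,K^{-1}\bar{p}\big) - \pi_{\text{ref}}\big(T\,D(p)\,K^{-1}\bar{p}\big) \Big\|_2
$$

where $\bar{p}$ is the homogeneous pixel coordinate, $\hat{D}$ the predicted depth, $D$ the lidar depth, $T$ the predicted pose to the reference camera, $\pi_{\text{ref}}$ the reference-camera projection, and $V$ the set of pixels with a valid lidar return.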

So, something like the following should work, shouldn't it? Am I missing something?

import torch
from utils.depth import depth2inv, inv2depth
from utils.camera import Camera
from utils.image import match_scales

def reprojected_distance_loss(depth_pred: torch.Tensor, depth_gt: torch.Tensor, mask: torch.Tensor,
                              ref_cam: Camera, cam: Camera) -> torch.Tensor:
    # Reconstruct world points from the target camera, once with the lidar depth
    # and once with the predicted depth
    world_points = cam.reconstruct(depth_gt, frame='w')
    world_points_pred = cam.reconstruct(depth_pred, frame='w')
    # Project both point clouds onto the reference camera
    # (returns normalized pixel coordinates, shape [B, H, W, 2])
    ref_coords = ref_cam.project(world_points, frame='w')
    ref_coords_pred = ref_cam.project(world_points_pred, frame='w')
    # mask is [B, 1, H, W] while the coordinates are [B, H, W, 2],
    # so drop the channel dimension before indexing
    valid = mask.squeeze(1)
    return torch.linalg.norm(ref_coords[valid] - ref_coords_pred[valid], dim=-1).mean()

# Context (inside the loss module): `preds` are the multi-scale inverse depth
# predictions, `depth_gt` the lidar depth map, `poses` the predicted poses to the
# reference cameras, and K, ref_K, W, device the intrinsics / image width / device
masks = []
depth_gts = match_scales(depth_gt, preds, self.n, mode='nearest', align_corners=None)
depth_preds = [inv2depth(preds[i]) for i in range(self.n)]

for i in range(self.n):
    masks.append((depth_gts[i] > 0).detach())

loss = 0.
for i in range(len(poses)):
    # Generate cameras for all scales
    cams, ref_cams = [], []
    for j in range(self.n):
        _, _, DH, DW = depth_preds[j].shape
        scale_factor = DW / float(W)
        cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
        ref_cams.append(Camera(K=ref_K.float(), Tcw=poses[i]).scaled(scale_factor).to(device))
    # Sum the reprojected distance loss over all scales for this reference camera
    loss += sum([reprojected_distance_loss(depth_preds[j], depth_gts[j], masks[j], ref_cams[j], cams[j])
                 for j in range(self.n)])

# Average over reference cameras and scales
loss /= len(poses)
loss /= self.n
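
For a quick sanity check, something like the following should run the function end-to-end on random tensors. The shapes, intrinsics and import path below are just placeholders for this example, and I'm assuming the same Camera class as above; note that Camera.project() returns coordinates normalized to [-1, 1], so the distance is measured in normalized image space rather than pixels.

import torch
from utils.camera import Camera  # same Camera as in the snippet above

# Made-up batch size, resolution and intrinsics, only to exercise the function
B, H, W = 2, 48, 160
K = torch.tensor([[100., 0., W / 2.],
                  [0., 100., H / 2.],
                  [0.,   0.,     1.]]).unsqueeze(0).repeat(B, 1, 1)

depth_gt = torch.rand(B, 1, H, W) * 80.    # fake "lidar" depth
depth_gt[:, :, ::4] = 0.                   # pretend some rows have no returns
depth_pred = torch.rand(B, 1, H, W) * 80.  # fake network prediction
mask = (depth_gt > 0).detach()

cam = Camera(K=K)      # target camera
ref_cam = Camera(K=K)  # identity pose, enough for a sanity check

loss = reprojected_distance_loss(depth_pred, depth_gt, mask, ref_cam, cam)
print(loss.item())     # scalar; should go to 0 when depth_pred == depth_gt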

iariav commented 2 years ago

@VitorGuizilini-TRI, any update on when you plan to add the implementation of the Reprojected Distance Loss function? Thanks

aartykov commented 2 years ago

@iariav, @Wenchao-Du, @VitorGuizilini-TRI I've implemented the Reprojected Distance Loss function and opened a pull request. Hopefully it will satisfy your needs.

Best regards!