dingdingcai / GSPose

MIT License
75 stars 7 forks source link

confused about the GS_Refiner #13

Closed ansj11 closed 1 month ago

ansj11 commented 1 month ago

Hi~, your work is awesome! But I have a question about GS_Refiner: how does it optimize `_delta_R` and `_delta_T` when they are neither passed to the renderer nor added to `init_camera`? Is there some magic in the code?

def multiple_refine_pose_with_GS_refiner(obj_data, init_pose, gaussians, device):
    """Refine every candidate pose with the Gaussian-splatting refiner and pick the best.

    Each pose in ``init_pose`` is turned into a ``GS_Camera``, refined by
    ``GS_Refiner`` and scored with a Sobel edge error between the query image
    and the final render. The candidate with the smallest edge error wins.

    Args:
        obj_data: mapping providing ``'rgb_image'`` (3xSxS per the original
            comment), ``'FovX'`` and ``'FovY'``.
        init_pose: Kx4x4 stack of initial poses; must support ``len`` and
            ``.copy()`` (presumably a numpy array -- confirm against caller).
        gaussians: Gaussian model exposing ``initialize_pose()``,
            ``_delta_R``, ``_delta_T`` and ``get_delta_pose``.
        device: device the target image is moved to before optimisation.

    Returns:
        The output dict of the LAST refined candidate (keys ``'gs3d_delta_RT'``,
        ``'iter_step'``, ``'render_img'``, ``'edge_err'``) with
        ``'gs3d_refined_RT'`` added, holding the refined pose of the BEST
        candidate. Only ``'gs3d_refined_RT'`` is guaranteed to describe the
        best candidate.

    Raises:
        ValueError: if ``init_pose`` contains no pose.
    """
    # Without this guard an empty ``init_pose`` skips the loop below and the
    # function dies on an unbound ``ret_outp`` with an unhelpful error.
    if len(init_pose) == 0:
        raise ValueError('init_pose must contain at least one 4x4 pose to refine')

    def GS_Refiner(image, mask, init_camera, gaussians, return_loss=False):
        """Optimise the pose delta stored on ``gaussians`` against one image.

        NOTE(review): the delta tensors are never handed to ``GS_Renderer``
        explicitly; presumably ``gaussians`` applies ``_delta_R``/``_delta_T``
        internally whenever the renderer queries it, which is how gradients
        reach them -- confirm in the Gaussian model class.

        Assumes ``CFG.MAX_STEPS >= 1``: ``iter_step`` and ``render_img`` are
        read after the loop.

        Returns a dict with ``'gs3d_delta_RT'``, ``'iter_step'`` and
        ``'render_img'``; ``'edge_err'`` is added when ``return_loss`` is true.
        """
        # Normalise the image to 3xSxS and the mask to 1xSxS.
        if image.dim() == 4:
            image = image.squeeze(0)
        if image.shape[2] == 3:
            image = image.permute(2, 0, 1) # 3xSxS
        if mask is None:
            mask = torch.ones_like(image[0])
        if mask.dim() == 2:
            mask = mask[None, :, :]
        if mask.dim() == 4:
            mask = mask.squeeze(0)
        if mask.shape[2] == 1:
            mask = mask.permute(2, 0, 1) # 1xSxS

        assert(image.dim() == 3 and image.shape[0] == 3), image.shape

        # NOTE: only needed by the optional ``* trunc_mask`` on the render
        # below, which is currently commented out.
        trunc_mask = (image.sum(dim=0, keepdim=True) > 0).type(torch.float32) # 1xSxS
        target_img = (image * mask).to(device).float()

        gaussians.initialize_pose() # initialize 0
        optimizer = optim.AdamW([gaussians._delta_R, gaussians._delta_T], lr=CFG.START_LR)
        lr_scheduler = CosineAnnealingWarmupRestarts(optimizer, CFG.MAX_STEPS,
                                                    warmup_steps=CFG.WARMUP,
                                                    max_lr=CFG.START_LR,
                                                    min_lr=CFG.END_LR)
        iter_losses = list()
        for iter_step in range(CFG.MAX_STEPS):
            render_img = GS_Renderer(init_camera, gaussians, gaussian_PipeP, gaussian_BG)['render'] # * trunc_mask
            loss = 0.0

            loss += MSELoss(render_img, target_img).mean()
            # if CFG.USE_SSIM:
            #     loss  += (1 - SSIM_METRIC(render_img[None, ...], target_img[None, ...]))
            # if CFG.USE_MS_SSIM:
            #     loss += (1 - MS_SSIM_METRIC(render_img[None, ...], target_img[None, ...]))

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            iter_losses.append(loss.item())
            if iter_step >= CFG.EARLY_STOP_MIN_STEPS:
                # Stop once the mean absolute loss change over the most recent
                # steps falls below the configured threshold.
                losses = torch.as_tensor(iter_losses)
                loss_grads = (losses[1:] - losses[:-1]).abs()
                if loss_grads[-CFG.EARLY_STOP_MIN_STEPS:].mean() < CFG.EARLY_STOP_LOSS_GRAD_NORM: # early stop the refinement
                    break

            # Debug dump (always on): log the step and save the average of the
            # render and the target so the alignment can be inspected.
            if True:
                print(iter_step, loss.item(), gaussians._delta_R)
                cat = (render_img.detach() + target_img).permute(1,2,0).cpu().numpy()
                save_path = 'debug/%04d.jpg' % iter_step
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                cv2.imwrite(save_path, cat / 2 * 255)
        gs3d_delta_RT = gaussians.get_delta_pose.squeeze(0).detach().cpu().numpy()

        outp = {
            'gs3d_delta_RT': gs3d_delta_RT,
            'iter_step': iter_step,
            'render_img': render_img,
        }

        if return_loss:
            # Run the edge comparison on the render's device: the query image
            # is not guaranteed to live there (only target_img was moved), and
            # conv2d requires input and kernel on the same device.
            edge_device = render_img.device
            sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).view(1, 1, 3, 3).repeat(1, 3, 1, 1)
            sobel_x = sobel_x.to(edge_device)
            sobel_y = sobel_x.transpose(-2, -1)
            query_img = image[None].float().to(edge_device)
            # Apply Sobel filter to the images
            query_sobel_h = torch_F.conv2d(query_img, sobel_x, padding=0)
            query_sobel_v = torch_F.conv2d(query_img, sobel_y, padding=0)
            rend_sobel_h = torch_F.conv2d(render_img[None], sobel_x, padding=0)
            rend_sobel_v = torch_F.conv2d(render_img[None], sobel_y, padding=0)
            edge_err = (query_sobel_h - rend_sobel_h).abs().mean() + (query_sobel_v - rend_sobel_v).abs().mean()
            outp['edge_err'] = edge_err

        return outp

    image = obj_data['rgb_image']   # 3xSxS
    FovX = obj_data['FovX']
    FovY = obj_data['FovY']

    gs3d_refined_errors = list()
    gs3d_refined_RTs = init_pose.copy() # Kx4x4
    for idx, init_RT in enumerate(init_pose):
        # The camera receives the transposed rotation block of the pose.
        init_camera = GS_Camera(T=init_RT[:3, 3], R=init_RT[:3, :3].T,
                                FoVx=FovX, FoVy=FovY,
                                cx_offset=0, cy_offset=0,
                                image=image, colmap_id=0, uid=0, image_name='', gt_alpha_mask=None, data_device=device)

        ret_outp = GS_Refiner(image=image, mask=None, init_camera=init_camera, gaussians=gaussians, return_loss=True)
        gs3d_delta_RT = ret_outp['gs3d_delta_RT']
        refined_err = ret_outp['edge_err']
        # Compose the optimised delta onto the initial pose.
        gs3d_refined_RTs[idx] = init_RT @ gs3d_delta_RT
        gs3d_refined_errors.append(refined_err)
    gs3d_refined_errors = torch.as_tensor(gs3d_refined_errors)
    best_idx = gs3d_refined_errors.argmin().item()
    gs3d_refined_RT = gs3d_refined_RTs[best_idx]

    # ret_outp is the dict of the last candidate; only the pose added here
    # belongs to the best one.
    ret_outp['gs3d_refined_RT'] = gs3d_refined_RT

    return ret_outp