facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/
Other

Is there any sample code for camera position optimization + mesh prediction via silhouette rendering #407

Closed: rezhv closed this issue 3 years ago

rezhv commented 3 years ago

❓ camera position optimization + mesh prediction via silhouette rendering

albertotono commented 3 years ago

Dear @rezhv, do these tutorials work for you? https://github.com/facebookresearch/pytorch3d/tree/master/docs/tutorials

Specifically https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb
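The linked camera position tutorial builds the rotation `R` with PyTorch3D's `look_at_rotation` and the translation as `T = -R^T C`, where `C` is the camera position. The same two lines appear in the model further down this thread. The sketch below is a plain-PyTorch check of why that translation is right, assuming PyTorch3D's row-vector convention `X_view = X_world @ R + T`. `look_at_rotation_sketch` is a hypothetical stand-in written for illustration, not the library function, and it assumes an up vector of `(0, 1, 0)` and a look-at target at the world origin.

```python
import torch
import torch.nn.functional as F


def look_at_rotation_sketch(camera_position, at=(0.0, 0.0, 0.0), up=(0.0, 1.0, 0.0)):
    # Hypothetical plain-torch look-at rotation: the columns of R are the
    # camera's x, y, z axes in world coordinates, with z pointing at `at`.
    at = torch.tensor(at, dtype=camera_position.dtype).expand_as(camera_position)
    up = torch.tensor(up, dtype=camera_position.dtype).expand_as(camera_position)
    z_axis = F.normalize(at - camera_position, dim=-1)
    x_axis = F.normalize(torch.cross(up, z_axis, dim=-1), dim=-1)
    y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=-1), dim=-1)
    return torch.stack([x_axis, y_axis, z_axis], dim=-1)  # (N, 3, 3), columns = axes


camera_position = torch.tensor([[3.0, 6.9, 2.5]])  # (1, 3), the camera center C
R = look_at_rotation_sketch(camera_position)        # (1, 3, 3)
# Same expression as in the model: T = -R^T C, written with a batched matmul.
T = -torch.bmm(R.transpose(1, 2), camera_position[:, :, None])[:, :, 0]  # (1, 3)


def world_to_view(points, R, T):
    # Row-vector convention: X_view = X_world @ R + T.
    return points @ R + T[:, None, :]


camera_in_view = world_to_view(camera_position[:, None, :], R, T)
origin_in_view = world_to_view(torch.zeros(1, 1, 3), R, T)
print(camera_in_view)  # approximately (0, 0, 0): the camera sits at the view-space origin
print(origin_in_view)  # approximately (0, 0, |C|): the look-at target lies on the +z axis
```

Because `C @ R + T = 0` holds for any rotation, this `T` always places the camera at the view-space origin; the look-at construction additionally puts the target straight ahead on the camera's z axis.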

rezhv commented 3 years ago

Thanks for your quick reply, @albertotono. What I am looking for is something that combines camera optimization and mesh prediction from silhouettes at the same time. I want a model that I can train with only pictures, with no 3D object to use as ground truth. I have the following model, which trains with only one camera view:

    import torch
    import torch.nn as nn
    from pytorch3d.loss import (
        mesh_edge_loss,
        mesh_laplacian_smoothing,
        mesh_normal_consistency,
    )
    from pytorch3d.renderer import look_at_rotation


    def update_mesh_shape_prior_losses(mesh, loss):
        # Shape priors, as in the PyTorch3D textured-mesh fitting tutorial.
        loss["edge"] = mesh_edge_loss(mesh)
        loss["normal"] = mesh_normal_consistency(mesh)
        loss["laplacian"] = mesh_laplacian_smoothing(mesh, method="uniform")


    class Model(nn.Module):
        def __init__(self, meshes, renderer, image_ref, losses):
            super().__init__()
            # `meshes` is the undeformed source mesh; the deformed mesh is rebuilt from it each step.
            self.src_mesh = meshes
            self.device = meshes.device
            self.renderer = renderer
            # `losses` maps each loss name to {"weight": float, "values": list}.
            self.losses = losses

            # Reference silhouette: the non-zero pixels of the reference RGB image, in [0, 1].
            self.register_buffer("image_ref", image_ref)

            # Per-vertex offsets to optimize, starting at zero.
            verts_shape = meshes.verts_packed().shape
            self.deform_verts = nn.Parameter(torch.zeros(verts_shape, device=self.device))

            # Fixed camera position (x, y, z). To optimize it together with the mesh,
            # use the nn.Parameter version below instead of the buffer.
            self.register_buffer(
                "camera_position",
                torch.tensor([3.0, 6.9, 2.5], device=self.device),
            )
            # self.camera_position = nn.Parameter(
            #     torch.tensor([3.0, 6.9, 2.5], device=self.device))

        def forward(self):
            # Deform first, so the shape priors see the current mesh
            # rather than the one from the previous iteration.
            meshes = self.src_mesh.offset_verts(self.deform_verts)

            loss = {k: torch.tensor(0.0, device=self.device) for k in self.losses}
            update_mesh_shape_prior_losses(meshes, loss)

            # Rotation and translation computed from the current camera position.
            R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
            T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)

            # R and T do not change between views here, so one render gives the same
            # silhouette loss as averaging num_views identical renders.
            image = self.renderer(meshes_world=meshes, R=R, T=T)
            predicted_silhouette = image[..., 3]
            loss["silhouette"] = ((predicted_silhouette.squeeze() - self.image_ref.squeeze()) ** 2).mean()

            # Weighted sum of the losses; log plain floats so old graphs are not kept alive.
            sum_loss = torch.tensor(0.0, device=self.device)
            for k, l in loss.items():
                sum_loss = sum_loss + l * self.losses[k]["weight"]
                self.losses[k]["values"].append(float(l.detach()))
            return sum_loss, image

However, the edge loss does not seem to go down.
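An edge-length regularizer only pulls on the vertex offsets if it is evaluated on the mesh after the offsets are applied. If it is computed on the mesh from before `offset_verts` (for example, by calling `update_mesh_shape_prior_losses` before reassigning the deformed mesh), each step regularizes a stale mesh, and on the first step the undeformed mesh has no gradient path to `deform_verts` at all. This is worth ruling out, though it is not necessarily the whole story. The plain-PyTorch sketch below uses a hypothetical `edge_length_loss` (mean squared edge length, in the spirit of `pytorch3d.loss.mesh_edge_loss` with its default target length of 0) on a toy triangle to show the loss decreasing when it is computed on the deformed vertices.

```python
import torch


def edge_length_loss(verts, edges, target_length=0.0):
    # Mean of (edge length - target)^2 over all edges; a hypothetical
    # plain-torch stand-in for a mesh edge-length regularizer.
    v0, v1 = verts[edges[:, 0]], verts[edges[:, 1]]
    return (((v0 - v1).norm(dim=1) - target_length) ** 2).mean()


# Toy "source mesh": one triangle and its three edges.
src_verts = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
edges = torch.tensor([[0, 1], [1, 2], [2, 0]])

# Per-vertex offsets being optimized, like deform_verts in the model.
deform_verts = torch.zeros_like(src_verts, requires_grad=True)
optimizer = torch.optim.SGD([deform_verts], lr=0.05)

history = []
for step in range(50):
    optimizer.zero_grad()
    # Apply the offsets first, then regularize the *current* mesh.
    verts = src_verts + deform_verts
    loss = edge_length_loss(verts, edges)
    loss.backward()
    optimizer.step()
    history.append(loss.item())

print(history[0], history[-1])  # the last value is smaller than the first
```

With a target length of 0 this term only ever shrinks edges, so in a real fit its weight has to be balanced against the silhouette term.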

rezhv commented 3 years ago

[Image attachment: Unknown-2]
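For the goal stated earlier in this thread of optimizing the camera and the mesh at the same time, one option is to make both the vertex offsets and the camera position `nn.Parameter`s (as in the commented-out `self.camera_position = nn.Parameter(...)` lines in the model) and give them different learning rates through optimizer parameter groups. In the sketch below, `JointModel` and its quadratic `toy_loss` are hypothetical stand-ins for a model whose loss would come from a differentiable silhouette renderer; only the parameter and optimizer wiring is the point.

```python
import torch
import torch.nn as nn


class JointModel(nn.Module):
    # Hypothetical toy model: the per-vertex offsets and the camera position
    # are both trainable, as in a joint camera + mesh fit.
    def __init__(self, num_verts):
        super().__init__()
        self.deform_verts = nn.Parameter(torch.zeros(num_verts, 3))
        self.camera_position = nn.Parameter(torch.tensor([3.0, 6.9, 2.5]))

    def toy_loss(self, target_offsets, target_camera):
        # Quadratic stand-in for "silhouette loss + shape priors"; a real model
        # would render the deformed mesh from camera_position instead.
        mesh_term = ((self.deform_verts - target_offsets) ** 2).mean()
        camera_term = ((self.camera_position - target_camera) ** 2).sum()
        return mesh_term + camera_term


model = JointModel(num_verts=4)
# One optimizer, two parameter groups, so the mesh and the camera can use
# different step sizes. The values here are arbitrary.
optimizer = torch.optim.Adam(
    [
        {"params": [model.deform_verts], "lr": 1e-2},
        {"params": [model.camera_position], "lr": 5e-2},
    ]
)

target_offsets = torch.full((4, 3), 0.1)
target_camera = torch.tensor([2.0, 5.0, 2.0])
start_camera = model.camera_position.detach().clone()

for step in range(200):
    optimizer.zero_grad()
    loss = model.toy_loss(target_offsets, target_camera)
    loss.backward()  # gradients flow to both parameter groups
    optimizer.step()

print(start_camera, model.camera_position.detach())
```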

gkioxari commented 3 years ago

First, @albertotono, thank you for helping out and pointing to the tutorials! This is fantastic!!

@rezhv It's unclear to me whether you are reporting a bug or asking a research question. Note that the issues of this repo should be focused on bugs or new feature requests. Of course, discussing research problems is important, and maybe someone here can help you with your optimization. However, the PyTorch3D team does not have the capacity to consult on users' projects. We hope you find a solution to your problem, and good luck with your project!

xhuan8 commented 3 years ago

@rezhv, did you solve this issue? I followed the same procedure, but the silhouette loss goes up.
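When a silhouette loss moves the wrong way, one thing worth ruling out (not a diagnosis of either poster's setup) is a reference mask that does not match the rendered alpha channel. Typical mismatches are a different shape after `.squeeze()`, an RGB image instead of a single-channel mask, or values in 0 to 255 rather than 0 to 1. The hypothetical helper below makes those assumptions explicit before computing the same mean squared error used in the model above.

```python
import torch


def checked_silhouette_loss(rendered_rgba, image_ref):
    # rendered_rgba: (N, H, W, 4) renderer output; channel 3 is the alpha/silhouette.
    # image_ref: reference mask expected as (H, W) or (1, H, W), values in [0, 1].
    predicted = rendered_rgba[..., 3].squeeze()
    reference = image_ref.squeeze().to(predicted.dtype)
    if predicted.shape != reference.shape:
        raise ValueError(f"shape mismatch: {tuple(predicted.shape)} vs {tuple(reference.shape)}")
    if reference.min() < 0 or reference.max() > 1:
        raise ValueError("reference mask should be in [0, 1]; got values outside that range")
    return ((predicted - reference) ** 2).mean()


rendered = torch.zeros(1, 8, 8, 4)
rendered[0, 2:6, 2:6, 3] = 1.0  # a 4x4 square silhouette
mask = torch.zeros(8, 8)
mask[2:6, 2:6] = 1.0            # matching reference mask
print(checked_silhouette_loss(rendered, mask).item())  # 0.0

try:
    checked_silhouette_loss(rendered, mask * 255)  # wrong value range
except ValueError as err:
    print(err)
```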