spla-tam / SplaTAM

SplaTAM: Splat, Track & Map 3D Gaussians for Dense RGB-D SLAM (CVPR 2024)
https://spla-tam.github.io/
BSD 3-Clause "New" or "Revised" License

Pose Backward Propagation #64

Closed fangli333 closed 8 months ago

fangli333 commented 9 months ago

Hi,

I am looking at how you do backpropagation and I have some questions. In the 'transform_to_frame' function, I think you set requires_grad = True for all camera parameters. But splatam.py contains lines like the following:

rendervar['means2D'].retain_grad()
im, radius, _, = Renderer(raster_settings=curr_data['cam'])(**rendervar)
variables['means2D'] = rendervar['means2D'] 

What puzzles me is that although you enable gradients for 'cam', the renderer itself does not seem to backpropagate to anything related to 'cam'. Could you point out which lines of code perform backpropagation for the camera, and explain why? Thank you. It might be a stupid question, but I am reproducing your work and cannot find where the camera gradients are computed.

yaochengrong commented 9 months ago

To my knowledge, in PyTorch, as long as "cam" is defined with requires_grad = True, autograd will compute the gradient of "cam" automatically, but with some preconditions, e.g. the computation graph is straightforward. Hopefully this is helpful.
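A minimal sketch of one such precondition, assuming a custom autograd op whose backward pass returns no gradient for its camera argument. `OpaqueRender` is a hypothetical stand-in written for illustration, not the 3DGS rasterizer's actual code. It shows that requires_grad = True alone is not enough: the tensor only receives a gradient if every op between it and the loss propagates one.

```python
import torch


class OpaqueRender(torch.autograd.Function):
    """Hypothetical stand-in for a custom op that only differentiates w.r.t. pts."""

    @staticmethod
    def forward(ctx, pts, cam):
        ctx.save_for_backward(cam)
        return pts * cam

    @staticmethod
    def backward(ctx, grad_out):
        (cam,) = ctx.saved_tensors
        # Gradient for pts is returned; None means "no gradient" for cam.
        return grad_out * cam, None


cam = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
pts = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)

out = OpaqueRender.apply(pts, cam)
out.sum().backward()

print(pts.grad)  # populated, because backward returns a gradient for pts
print(cam.grad)  # None, even though cam.requires_grad is True
```

In this sketch the camera tensor is marked as requiring gradients, yet it never receives one, which seems to be the situation the original question describes.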

cuijh26 commented 9 months ago

I have the same question. Were you able to solve it?

fangli333 commented 9 months ago

Actually, no. The authors have not discussed it.

cuijh26 commented 9 months ago

so sad

Nik-V9 commented 9 months ago

Hi, thanks for your interest in our work! You are correct that we don't receive backward gradients for the camera through the current 3DGS rasterizer. When you render a 3D Gaussian Map with the official 3DGS rasterizer, it converts the Gaussians from the world frame to the frame of the camera you pass in and renders them. Here's how we receive the gradients through PyTorch instead:

  1. We define only one 3DGS camera for the first frame (this is the canonical or rendering camera, usually an Identity matrix for SplaTAM because the first frame is the world origin). Also, our Gaussian Map is always in the world frame. https://github.com/spla-tam/SplaTAM/blob/a0bda58dd6fbf3e2ad31e40adc48514923bec4c0/scripts/splatam.py#L187
  2. The subsequent camera poses are defined as a global rigid rotation and translation on the world frame Gaussians. https://github.com/spla-tam/SplaTAM/blob/a0bda58dd6fbf3e2ad31e40adc48514923bec4c0/scripts/splatam.py#L145
  3. When we want to render the world frame Gaussian map from a particular pose, we apply the rotation and translation on the Gaussians and pass them to the original 3DGS rasterizer. https://github.com/spla-tam/SplaTAM/blob/a0bda58dd6fbf3e2ad31e40adc48514923bec4c0/utils/slam_helpers.py#L216

Once you understand the above process, you can see that in our loss computation, we do precisely the same: Transform Gaussians to camera frame, Initialize Render Variables, Rasterize, Compute Loss, and Backpropagate. https://github.com/spla-tam/SplaTAM/blob/a0bda58dd6fbf3e2ad31e40adc48514923bec4c0/scripts/splatam.py#L223
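The steps above can be sketched roughly as follows. This is an illustrative toy under stated assumptions, not the repository's code: `quat_to_rotmat`, `transform_to_camera`, and `toy_rasterize` are hypothetical helpers, the pose is assumed to be an unnormalized quaternion plus a translation, and a simple pinhole projection stands in for the 3DGS rasterizer. The point is that the "renderer" never sees the pose; the pose reaches the loss only through the PyTorch transform applied to the Gaussian centers.

```python
import torch
import torch.nn.functional as F


def quat_to_rotmat(q):
    """Convert a unit quaternion (w, x, y, z) to a 3x3 rotation matrix."""
    w, x, y, z = q
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])


def transform_to_camera(means, unnorm_rot, trans):
    """Apply the (differentiable) rigid pose to world-frame centers in PyTorch."""
    rot = quat_to_rotmat(F.normalize(unnorm_rot, dim=0))
    return means @ rot.T + trans


def toy_rasterize(means_cam):
    """Stand-in renderer with a fixed identity camera: takes only the centers."""
    return means_cam[:, :2] / means_cam[:, 2:3]


# World-frame Gaussian centers, kept fixed here (as when only the pose is optimized).
means_world = torch.tensor([[0.0, 0.0, 2.0], [0.5, -0.2, 3.0], [-0.3, 0.4, 2.5]])

# Pose parameters to optimize (hypothetical names), initialized to identity.
cam_unnorm_rot = torch.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
cam_trans = torch.zeros(3, requires_grad=True)

# Synthetic "observation": the same scene shifted slightly along x.
target = toy_rasterize(means_world + torch.tensor([0.1, 0.0, 0.0]))

# Transform -> rasterize -> loss -> backpropagate.
pred = toy_rasterize(transform_to_camera(means_world, cam_unnorm_rot, cam_trans))
loss = (pred - target).abs().sum()
loss.backward()

print(cam_trans.grad)       # populated via the PyTorch transform
print(cam_unnorm_rot.grad)  # populated via the PyTorch transform
```

Because the renderer only needs to return gradients for the transformed centers, autograd can carry them back through `transform_to_camera` to the pose parameters, with no camera gradient required from the rasterizer itself.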

Also, please see a similar issue: https://github.com/spla-tam/SplaTAM/issues/28

This issue may also be of interest for understanding our other gradient computation tricks: https://github.com/spla-tam/SplaTAM/issues/54#issuecomment-1887381644

Nik-V9 commented 8 months ago

Closing due to inactivity.