daniel03c1 / masked_wavelet_nerf

MIT License
79 stars · 5 forks

About detach() function #10

Closed xuyaojian123 closed 11 months ago

xuyaojian123 commented 11 months ago

Thank you for the great work! I have some questions about the code. I see a lot of detach() calls, like this:

if self.use_mask:
    mask = torch.sigmoid(self.density_plane_mask[idx])
    plane = (plane * (mask >= 0.5) - plane * mask).detach() \
            + plane * mask
    mask = torch.sigmoid(self.density_line_mask[idx])
    line = (line * (mask >= 0.5) - line * mask).detach() \
           + line * mask

Is this different from the code below?

if self.use_mask:
    mask = torch.sigmoid(self.density_plane_mask[idx])
    plane = (plane * (mask >= 0.5)).detach()
    mask = torch.sigmoid(self.density_line_mask[idx])
    line = (line * (mask >= 0.5)).detach()

In addition, I don’t know why the variables plane and line should be detached from the computation graph. Looking forward to your reply. Thank you!

daniel03c1 commented 11 months ago

The issues you have pointed out are actually implementation details. More precisely, the detaches are for the straight-through-estimator technique. If you use detach as "(plane * (mask >= 0.5)).detach()", you cannot update the masks and planes properly. The reason for using detach is to make the forward and backward passes different. During the forward pass, you can ignore "detach", and the code is effectively "plane * (mask >= 0.5) - plane * mask + plane * mask", which equals "plane * (mask >= 0.5)". During the backward pass, "detach" stops gradients from flowing through the first term, so the gradient comes only from "plane * mask".
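The trick above can be checked with a minimal, self-contained sketch (toy tensors standing in for the repository's actual plane and mask parameters):

```python
import torch

# Toy stand-ins for a feature plane and its mask logits.
plane = torch.tensor([1.0, -2.0, 0.5, 3.0], requires_grad=True)
mask_logits = torch.tensor([2.0, -1.0, 0.3, -0.2], requires_grad=True)
mask = torch.sigmoid(mask_logits)

# Straight-through estimator: forward uses the hard (binary) mask,
# backward flows through the soft product plane * mask.
hard = plane * (mask >= 0.5)
out = (hard - plane * mask).detach() + plane * mask

# Forward value equals the hard-masked plane.
print(torch.allclose(out, hard))  # True

# Backward: the detached term contributes nothing, so gradients
# come from plane * mask and mask_logits is trainable.
out.sum().backward()
print(mask_logits.grad is not None)  # True
```

In contrast, `(plane * (mask >= 0.5)).detach()` detaches the whole expression, so neither the plane nor the mask would receive any gradient at all.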

xuyaojian123 commented 11 months ago

Thanks!

xuyaojian123 commented 11 months ago

I thought about it again. If I use

if self.use_mask:
    mask = torch.sigmoid(self.density_plane_mask[idx])
    plane = plane * (mask >= 0.5)
    mask = torch.sigmoid(self.density_line_mask[idx])
    line = line * (mask >= 0.5)

Can it have the same effect as yours? Can it still correctly update the masks, planes, and lines? I'm a beginner in PyTorch.

xuyaojian123 commented 11 months ago

plane * mask

We can't use (mask >= 0.5) to update the gradient, because it's a step function whose gradient is zero almost everywhere? So we want the continuous function plane * mask to carry the gradient instead. Is that so?

daniel03c1 commented 11 months ago

Using that code may lead to improper gradient calculation: the operation (mask >= 0.5) produces a binary output, for which gradients are not defined, so the mask parameters would never be updated.
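This can be verified with a tiny check (toy tensors, assumed for illustration): the comparison returns a boolean tensor with no grad_fn, so the mask logits never receive a gradient:

```python
import torch

plane = torch.tensor([1.0, -2.0], requires_grad=True)
mask_logits = torch.tensor([2.0, -1.0], requires_grad=True)
mask = torch.sigmoid(mask_logits)

# (mask >= 0.5) is a non-differentiable boolean tensor.
out = plane * (mask >= 0.5)
out.sum().backward()

print(plane.grad)        # tensor([1., 0.]) -- plane still gets a gradient
print(mask_logits.grad)  # None -- the mask cannot be trained this way
```

This is exactly why the straight-through estimator is needed: it keeps the hard mask in the forward pass while routing gradients through the differentiable product plane * mask.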