yccyenchicheng / AutoSDF


Reconstruction loss in supp #19

Open Kitsunetic opened 2 years ago

Kitsunetic commented 2 years ago

Hi, thank you for the great work.

I have a question about Eq. 1 of the supplementary material: $\mathcal L_\text{VQ-VAE}=-\log p(\mathbf X|\mathbf Z) + || \text{sg}[\hat{\mathbf Z}]-\mathbf Z||^2_2+|| \hat{\mathbf Z} - \text{sg}[\mathbf Z]||^2_2$

But here, it seems that the code is actually computing the MAE between the predicted SDF and the GT SDF. So I think the correct form is $\mathcal L_\text{VQ-VAE}=|| \hat{\mathbf X}-\mathbf X ||_1 + || \text{sg}[\hat{\mathbf Z}]-\mathbf Z||^2_2+|| \hat{\mathbf Z} - \text{sg}[\mathbf Z]||^2_2$. Am I right?

Additionally, the $\beta$ hyperparameter of the vector quantizer is $1.0$. So would the VQ loss without stop-gradient be the same as the proposed loss function? That is, $|| \hat{\mathbf Z} - \mathbf Z||^2_2 = || \text{sg}[\hat{\mathbf Z}]-\mathbf Z||^2_2+|| \hat{\mathbf Z} - \text{sg}[\mathbf Z]||^2_2$?
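For concreteness, here is a minimal PyTorch sketch of the loss as I read it (the second equation above), assuming $\hat{\mathbf Z}$ is the encoder output and $\mathbf Z$ the quantized codebook entries; the names `x_hat`, `x`, `z_e`, `z_q` are only illustrative and not taken from the repo:

```python
import torch
import torch.nn.functional as F

def vqvae_loss(x_hat, x, z_e, z_q, beta=1.0):
    """Sketch of the loss: L1 reconstruction + codebook + commitment terms.

    x_hat: decoded SDF, x: GT SDF,
    z_e:   encoder output (Z_hat), z_q: quantized codes (Z).
    """
    recon = F.l1_loss(x_hat, x)                 # ||X_hat - X||_1
    codebook = F.mse_loss(z_q, z_e.detach())    # ||sg[Z_hat] - Z||_2^2
    commit = F.mse_loss(z_e, z_q.detach())      # ||Z_hat - sg[Z]||_2^2
    return recon + codebook + beta * commit
```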

yccyenchicheng commented 1 year ago

Hi @Kitsunetic,

Sorry for the late reply!

Yes, I think the second equation you wrote is correct. The negative log-likelihood term is the reconstruction loss between the GT SDF and the predicted SDF.

I don’t quite get your final question. You need the stop-gradient because the codebook lookup is not differentiable. In the original VQ-VAE paper this is called the “straight-through” gradient estimator (Section 3.2 of https://arxiv.org/pdf/1711.00937v2.pdf). So I think they are different.
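For reference, a common way to implement the straight-through estimator in PyTorch is the reparameterization below (a minimal sketch, not necessarily the exact code in this repo):

```python
import torch

def quantize_straight_through(z_e, codebook):
    """Nearest-neighbor lookup with straight-through gradients.

    z_e:      encoder output, shape (N, D)
    codebook: embedding table, shape (K, D)
    """
    # Non-differentiable nearest-neighbor assignment.
    dists = torch.cdist(z_e, codebook)   # (N, K) pairwise distances
    idx = dists.argmin(dim=1)            # (N,) code indices
    z_q = codebook[idx]                  # (N, D) quantized codes

    # Straight-through: the forward value is z_q, but the gradient of the
    # reconstruction loss flows to z_e as if quantization were the identity.
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, z_q, idx
```

The decoder is then fed `z_q_st`, so the reconstruction loss still reaches the encoder, while the two stop-gradient terms update the codebook and the encoder separately.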

hmax233 commented 1 year ago

I have the same question about the loss function. In the original VQ-VAE scripts, the reconstruction loss is backpropagated in two stages, because the gradient is stopped between $\hat{Z}$ and ${Z}$, so the reconstruction loss by itself cannot update the encoder. To solve this, the original VQ-VAE directly passes the gradient at ${Z}$ back to $\hat{Z}$ and backpropagates through $\hat{Z}$ to update the encoder. But in your scripts, the reconstruction loss is only backpropagated once, which as far as I can see cannot update the encoder. Here is the backward script of the original VQ-VAE:

```python
# Backward once through the summed loss (the gradient stops at the quantized codes),
# then copy the gradient at the decoder input back onto the encoder output Z_enc.
total_loss = recon_loss + sg_z_and_embd_loss + self.beta * z_and_sg_embd_loss
total_loss.backward(retain_graph=True)
Z_enc.backward(self.model.grad_for_encoder)
```
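For comparison, if the quantizer uses the straight-through reparameterization sketched earlier in this thread, a single backward call on the total loss is enough for the reconstruction gradient to reach the encoder; here is a minimal check (all names are illustrative stand-ins, not the repo's code):

```python
import torch

# Toy check: with z_q_st = z_e + (z_q - z_e).detach(), a single backward on the
# reconstruction loss produces a nonzero gradient at the encoder output z_e.
z_e = torch.randn(4, 8, requires_grad=True)   # stand-in encoder output
codebook = torch.randn(16, 8)                 # stand-in embedding table
idx = torch.cdist(z_e, codebook).argmin(dim=1)
z_q = codebook[idx]
z_q_st = z_e + (z_q - z_e).detach()           # straight-through value

recon_loss = (z_q_st ** 2).mean()             # stand-in for the decoder + reconstruction loss
recon_loss.backward()
print(z_e.grad.abs().sum().item() > 0)        # True: the encoder output receives a gradient
```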