jjihwan / SV3D-fine-tune

Fine-tuning code for SV3D
MIT License

Can you provide the weights of the SV3D encoder? #3

Open zhaosheng-thu opened 1 month ago

zhaosheng-thu commented 1 month ago

Thanks for your work. You've mentioned that video_latent.pt is the video latent encoded by the SV3D encoder, so where can I find the weights of this encoder? It seems that sv3d_p.safetensors only includes the VAE (first_stage_model) decoder weights. Since you already use first_stage_model.decode_first_stage() during logging and inference, could you please tell me how I can find the encoder weights? Thanks a lot!

jjihwan commented 1 month ago

You're right. sv3d_p.safetensors does not contain the encoder, since the authors ran the encoding as a preprocessing step before training, as mentioned in the SV3D paper.

Therefore, as you can see here, you might have to use SVD's encoder. Note that you should save the latents just before the 'regularization' step, which have channel size 8, not 4. Thank you for informing me. We'll add the description soon.
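For reference, here is a minimal sketch of what I mean, using diffusers' AutoencoderKLTemporalDecoder (the checkpoint id, the frames variable, and the output filename are placeholders/assumptions, not something shipped in this repo):

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder

# Checkpoint id is an assumption -- any SVD checkpoint with a "vae"
# subfolder should work the same way.
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="vae"
).eval()

with torch.no_grad():
    # frames: (B, 3, H, W), values in [-1, 1]
    dist = vae.encode(frames).latent_dist
    # dist.parameters is the raw encoder output, i.e. mean and logvar
    # concatenated along the channel dim: 8 channels, "just before
    # regularization". dist.sample() / dist.mode() would give 4 channels.
    moments = dist.parameters  # (B, 8, H/8, W/8)
    torch.save(moments, "video_latent.pt")
```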

zhaosheng-thu commented 1 month ago

> You're right. sv3d_p.safetensors does not contain the encoder, since the authors ran the encoding as a preprocessing step before training, as mentioned in the SV3D paper.
>
> Therefore, as you can see here, you might have to use SVD's encoder. Note that you should save the latents just before the 'regularization' step, which have channel size 8, not 4. Thank you for informing me. We'll add the description soon.

Thanks for your explanation. You've mentioned that I need to save the latents before regularization. I also noticed that in the original SV3D repository's config file, the target of regularizer_config is set to DiagonalGaussianRegularizer. So I'm wondering: if I replace DiagonalGaussianRegularizer with torch.nn.Identity, can I just save the latents generated directly by the VAE, like this?

```python
import torch
from diffusers import AutoencoderKLTemporalDecoder

vae = AutoencoderKLTemporalDecoder.from_pretrained("")
torch.save(vae.encode(img).latent_dist.mode() * 0.18215, "")
```

In DiagonalGaussianRegularizer, the sample function is implemented as

```python
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
```

i.e. not just the mean. Is there any difference? Thanks a lot for your prompt responses!

jjihwan commented 1 month ago

To my knowledge, DiagonalGaussianRegularizer is needed to train the VAE with the reparameterization trick. However, although the VAE is not trained during SV3D training, I used the regularizer to directly follow the config file from the original SV3D repository.

Actually, I also wondered whether I should use the regularizer or just use torch.nn.Identity when I implemented it, but I decided to follow the config file 🥲.

In some sense, I think the regularizer can introduce some stochasticity into the training phase. You can try it, and please let me know if you gain any insights from that.
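To make the stochasticity point concrete, here is a minimal plain-torch sketch of the two options, assuming an 8-channel "moments" tensor (mean and logvar concatenated along the channel dim) from the encoder; the function names are mine, not from either codebase:

```python
import torch

def reparameterized_sample(moments: torch.Tensor) -> torch.Tensor:
    # DiagonalGaussianRegularizer-style: a fresh stochastic sample
    # every time it is called (the reparameterization trick).
    mean, logvar = moments.chunk(2, dim=1)
    std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
    return mean + std * torch.randn_like(mean)

def deterministic_mode(moments: torch.Tensor) -> torch.Tensor:
    # The Identity-style alternative: always return the mean,
    # which is what latent_dist.mode() gives you in diffusers.
    mean, _ = moments.chunk(2, dim=1)
    return mean
```

Training on precomputed mode() latents removes that per-step noise, which is exactly the trade-off in question.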