threestudio-project / threestudio

A unified framework for 3D content generation.
Apache License 2.0

Effstabledreamfusion #492

Open jadevaibhav opened 2 months ago

jadevaibhav commented 2 months ago

Efficient training of DreamFusion-like systems on higher-resolution images

I am working on a feature for the DreamFusion system (which can be extended to others). The basic idea: to train at a higher image resolution, we subsample pixels from the image with a mask for NeRF rendering, then compute the SDS loss at the original resolution. The computational saving comes from subsampling the number of rays for NeRF training, while the diffusion model sees higher-resolution images (a better visual signal); the result is roughly the same compute cost.

Testing with the demo prompt, at 128x128 image resolution with 64x64 subsampling for NeRF training, I get the following result. Screenshot 2024-07-25 at 4 33 10 PM I would welcome any feedback on potential issues with this idea and how to improve the results. I look forward to hearing from the community! @DSaurus @voletiv @bennyguo @thuliu-yt16
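To make the idea concrete, here is a minimal sketch of one training step (all names here are placeholders, not the actual threestudio API): render only a random subset of rays, scatter the results into a zero-initialized full-resolution image, and compute the SDS loss at the original resolution.

```python
import torch

def masked_sds_step(H, W, n_samples, render_rays, sds_loss):
    # render_rays and sds_loss are hypothetical stand-ins for the real system.
    flat_idx = torch.randperm(H * W)[:n_samples]  # random pixel subset
    colors = render_rays(flat_idx)                # NeRF forward pass, (n_samples, 3)
    full = torch.zeros(H * W, 3)                  # zeros where nothing was rendered
    full[flat_idx] = colors                       # scatter rendered colors
    return sds_loss(full.view(H, W, 3))           # SDS at the original resolution
```

Only `n_samples` rays go through the NeRF, while the diffusion model always receives an image at the full `(H, W)` resolution.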

jadevaibhav commented 2 months ago

For comparison: with the efficient sampling method described above, training the NeRF at 128x128 resolution (subsampled to 64x64) takes ~30 min. Without efficient sampling, training at 128x128 resolution takes ~41 min, keeping all other parameters the same.

jadevaibhav commented 1 month ago

Hi @DSaurus, thanks for your approval! I have created a separate yaml config for this, so you just have to run:

python launch.py --config configs/dreamfusion-sd-eff.yaml --train  system.prompt_processor.prompt="a zoomed out DSLR photo of a baby bunny sitting on top of a stack of pancakes"

Here are the videos I generated, although they are not of good quality... I am still investigating the issue with generation quality, and whether this method can be extended to other generative systems.

https://github.com/user-attachments/assets/11bd2e65-8d86-48e2-b08d-8a0618589bea

https://github.com/user-attachments/assets/3f01f4a4-9ee0-4724-9363-a671ce19cfd0

DSaurus commented 1 month ago

Hi @jadevaibhav ,

Perhaps you could try caching the rendered images without gradients first. Then, sample some rays from this complete rendered image and update the corresponding pixels for the SDS process. I think this would be more robust for 3D generation.
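A rough sketch of this caching idea (placeholder names, not the actual threestudio API): render the full image once without gradients, then in each inner step re-render only a few sampled rays with gradients and paste them into the cached image before computing SDS, so the diffusion model always sees a complete image.

```python
import torch

def cached_sds_steps(H, W, n_rays, n_inner, render_rays, sds_loss):
    # render_rays and sds_loss are hypothetical stand-ins for the real system.
    with torch.no_grad():
        cache = render_rays(torch.arange(H * W))  # full render, gradient-free
    losses = []
    for _ in range(n_inner):
        idx = torch.randperm(H * W)[:n_rays]
        colors = render_rays(idx)                 # gradients flow through these rays only
        img = cache.clone()
        img[idx] = colors                         # overwrite sampled pixels in the cache
        losses.append(sds_loss(img.view(H, W, 3)))
    return losses
```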

jadevaibhav commented 1 month ago

@DSaurus, could you please explain what you mean here? If I understand correctly, caching multiple images before updating through SDS would be equivalent to directly generating higher-resolution images, which defeats the purpose of generating a sub-sampled grid... My idea is essentially to take advantage of the continuous representation of 3D space learned by the MLP: at each iteration we randomly sub-sample a set of ray directions, and over the complete optimization process we learn at the original (higher) resolution.

Here's my sub-sampling code for clarity:

import math

import torch
from jaxtyping import Float
from torch import Tensor


def mask_ray_directions(
    H: int,
    W: int,
    s_H: int,
    s_W: int,
) -> Float[Tensor, "s_H s_W"]:
    """
    Mask the (H, W) image down to (s_H, s_W) for efficient training at a higher
    image resolution. Pixels inside the centered (s_H, s_W) window are sampled
    with probability (1 - p); pixels outside it with probability p.
    The masking is deferred to before calling get_rays().
    """
    indices_all = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing="xy",
    )

    # Centered (s_H, s_W) window of the full-resolution grid.
    mask = torch.zeros(H, W, dtype=torch.bool)
    mask[(H - s_H) // 2 : H - math.ceil((H - s_H) / 2),
         (W - s_W) // 2 : W - math.ceil((W - s_W) / 2)] = True

    # Flat row-major pixel index: x + W * y (the original used H here, which is
    # only correct for square images).
    flat_idx = indices_all[0] + W * indices_all[1]
    in_ind_1d = flat_idx[mask]
    out_ind_1d = flat_idx[torch.logical_not(mask)]

    # Tried a 0.5 sampling ratio between inside and outside, as the smaller
    # inside area already receives proportionally more samples anyway.
    p = 0.5  # alternative: (s_H * s_W) / (H * W)
    n_in = int((1 - p) * s_H * s_W)
    n_out = s_H * s_W - n_in  # ensure counts sum exactly to s_H * s_W
    select_ind = in_ind_1d[
        torch.multinomial(torch.ones_like(in_ind_1d), n_in, replacement=False)
    ]
    select_ind = torch.cat(
        [
            select_ind,
            out_ind_1d[
                torch.multinomial(torch.ones_like(out_ind_1d), n_out, replacement=False)
            ],
        ],
        dim=0,
    ).to(dtype=torch.long).view(s_H, s_W)

    return select_ind
DSaurus commented 1 month ago

@jadevaibhav Sure, my idea is to reuse these cached images multiple times; each time, you apply your sub-sampler to update part of them. If my understanding is correct, the current mask sub-sampler renders incomplete images. However, diffusion models like Stable Diffusion are not designed to recover such incomplete images, which I think is why the current mask sub-sampler leads to unstable results.

jadevaibhav commented 1 month ago

@DSaurus the sub-sampler is applied to the generated directions, so only the selected directions are passed to the NeRF. When calculating the SDS loss, I pass an image at the original resolution with the rendered colors filled in at the selected indices and 0 elsewhere. I also believe that diffusion is unable to recover such an incomplete image. Rather than creating an incomplete image, I am thinking of interpolating from these rendered colors; that way, no gradients are wasted either. What do you think? I will be happy to continue the caching discussion on Discord if you want. Also, should we merge the current version in the meantime?
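One naive way to interpolate scattered rendered samples onto the full grid would be inverse-distance weighting (illustrative only, not what the PR implements; all names are hypothetical). Gradients flow from every full-resolution pixel back to the rendered colors, so none are wasted:

```python
import torch

def idw_fill(sample_rc, sample_colors, H, W, eps=1e-6):
    # sample_rc: (N, 2) row/col coordinates, sample_colors: (N, 3)
    grid_r, grid_c = torch.meshgrid(
        torch.arange(H, dtype=torch.float32),
        torch.arange(W, dtype=torch.float32),
        indexing="ij",
    )
    grid = torch.stack([grid_r, grid_c], dim=-1).view(-1, 1, 2)  # (H*W, 1, 2)
    d2 = ((grid - sample_rc.view(1, -1, 2)) ** 2).sum(-1)        # (H*W, N) squared distances
    w = 1.0 / (d2 + eps)                                         # inverse-distance weights
    w = w / w.sum(dim=1, keepdim=True)                           # normalize per pixel
    return (w @ sample_colors).view(H, W, 3)                     # weighted colors
```

This is O(H*W*N), so it is a sketch rather than a practical per-iteration operation at high resolution.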

jadevaibhav commented 2 weeks ago

Hi @DSaurus, thanks for approving the PR! I don't have write access, so could you please merge?

I looked into the "interpolation", but there is currently no straightforward way to do it with randomly sampled positions. I looked into the grid_sample() method, but I can't define a transformation or mapping from the original-resolution coordinate system to the sampled grid coordinates. I am now experimenting with uniform subsampling, with a random offset for the top-left grid corner.
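A sketch of this uniform-subsampling variant (placeholder names, assuming H and W are divisible by s_H and s_W): sample every stride-th pixel starting from a random offset, render those rays, then upsample the (s_H, s_W) rendering back to (H, W) with bilinear interpolation for SDS.

```python
import torch
import torch.nn.functional as F

def uniform_subsample_indices(H, W, s_H, s_W):
    stride_h, stride_w = H // s_H, W // s_W
    off_h = torch.randint(0, stride_h, (1,)).item()  # random offset so every
    off_w = torch.randint(0, stride_w, (1,)).item()  # pixel is reachable over time
    rows = off_h + stride_h * torch.arange(s_H)
    cols = off_w + stride_w * torch.arange(s_W)
    return (rows[:, None] * W + cols[None, :]).view(-1)  # flat (s_H*s_W,) indices

def upsample_for_sds(rendered, H, W):
    # rendered: (s_H, s_W, 3) -> (H, W, 3) via bilinear interpolation
    x = rendered.permute(2, 0, 1).unsqueeze(0)  # (1, 3, s_H, s_W)
    up = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)
    return up.squeeze(0).permute(1, 2, 0)
```

Because the samples lie on a regular grid, the interpolation back to full resolution is well-defined, unlike the randomly scattered case.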

jadevaibhav commented 4 days ago

I finished the new experiment, and it works better than before! The training time is still about the same (~33 min)!

Screenshot 2024-09-22 at 8 53 12 PM

https://github.com/user-attachments/assets/37e1d5fa-578c-4420-a3c8-5408d6a74e17