threestudio-project / threestudio

A unified framework for 3D content generation.

Effstabledreamfusion #492

Closed: jadevaibhav closed this 1 month ago

jadevaibhav commented 4 months ago

Efficient training of DreamFusion-like systems on higher-resolution images

I am working on a feature for the DreamFusion system (which can be extended to others). The basic idea: to train with a higher-resolution image, we sub-sample pixels from it with a mask for NeRF rendering, then calculate the SDS loss at the original image resolution. The computational benefit comes from sub-sampling the number of rays for NeRF rendering, while the diffusion model still sees the higher-resolution image (for a better visual model), so the overall compute cost stays roughly the same.
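
To make the idea concrete, here is a rough sketch of one training step. It is only a sketch under my assumptions: mask_ray_directions is the sub-sampler posted further below, and render_rays, guidance, and prompt_utils are placeholders for the renderer and the SDS guidance call, not the actual threestudio API.

import torch

def training_step_sketch(H, W, s_H, s_W, rays_o, rays_d, render_rays, guidance, prompt_utils):
    # Pick s_H * s_W pixel indices out of the full H x W grid.
    select_ind = mask_ray_directions(H, W, s_H, s_W).view(-1)

    # Render only the selected rays with NeRF (the cheap part).
    rgb_sel = render_rays(
        rays_o.view(-1, 3)[select_ind], rays_d.view(-1, 3)[select_ind]
    )  # (s_H * s_W, 3)

    # Scatter the rendered colors into a full-resolution image; the
    # remaining pixels stay zero.
    full = torch.zeros(H * W, 3, device=rgb_sel.device, dtype=rgb_sel.dtype)
    full[select_ind] = rgb_sel
    full = full.view(1, H, W, 3)

    # The SDS loss is computed at the original (higher) resolution.
    return guidance(full, prompt_utils)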

Testing with the demo prompt, at 128x128 image resolution with 64x64 sub-sampling for NeRF training, I get the following result.

Screenshot 2024-07-25 at 4 33 10 PM

I would appreciate any feedback on potential issues with this idea and on how to improve the results. I am looking forward to hearing from the community! @DSaurus @voletiv @bennyguo @thuliu-yt16

jadevaibhav commented 3 months ago

For comparison: with the efficient sampling method described above, training the NeRF at 128x128 resolution (sub-sampled to 64x64) takes ~30 min. Without efficient sampling, training at 128x128 resolution takes ~41 min, with all other parameters kept the same.

jadevaibhav commented 2 months ago

Hi @DSaurus, thanks for your approval! I have created a separate YAML config for this, so you just have to run:

python launch.py --config configs/dreamfusion-sd-eff.yaml --train  system.prompt_processor.prompt="a zoomed out DSLR photo of a baby bunny sitting on top of a stack of pancakes"

Here are the videos I generated, although the quality is not good... I am still investigating where the quality issue comes from, and whether this method can be extended to other generative systems.

https://github.com/user-attachments/assets/11bd2e65-8d86-48e2-b08d-8a0618589bea

https://github.com/user-attachments/assets/3f01f4a4-9ee0-4724-9363-a671ce19cfd0

DSaurus commented 2 months ago

Hi @jadevaibhav ,

Perhaps you could try caching the rendered images without gradients first. Then you sample some rays of this complete rendered image and update the corresponding pixels for the SDS step. I think this is more robust for 3D generation.
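
Roughly something like the sketch below (just to illustrate what I mean; render_full, render_rays, guidance, and prompt_utils are placeholders, and mask_ray_directions is your sub-sampler):

import torch

def cached_sds_step_sketch(H, W, s_H, s_W, render_full, render_rays, guidance, prompt_utils):
    # 1. Render the complete image once without gradients and cache it.
    with torch.no_grad():
        cached = render_full(H, W)  # (H, W, 3), no grad

    # 2. Sample a subset of rays and re-render only those with gradients.
    select_ind = mask_ray_directions(H, W, s_H, s_W).view(-1)
    rgb_sel = render_rays(select_ind)  # (s_H * s_W, 3), with grad

    # 3. Overwrite the corresponding pixels of the cached image, so the
    #    diffusion model always sees a complete image while gradients only
    #    flow through the re-rendered pixels.
    full = cached.view(-1, 3).clone()
    full[select_ind] = rgb_sel
    full = full.view(1, H, W, 3)

    # 4. Run SDS on the complete image.
    return guidance(full, prompt_utils)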

jadevaibhav commented 2 months ago

@DSaurus, could you please explain what you mean here? If I understand correctly, caching multiple images before updating through SDS would be equivalent to directly rendering at the bigger resolution, which defeats the purpose of rendering a sub-sampled grid... My idea is essentially to take advantage of the continuous representation of 3D space learned by the MLP: at each iteration we randomly sub-sample a set of ray directions, so over the complete optimization process we learn at the original (bigger) resolution.

Here's my sub-sampling code, for clarity:

import math

import torch
from jaxtyping import Int
from torch import Tensor


def mask_ray_directions(
    H: int,
    W: int,
    s_H: int,
    s_W: int,
) -> Int[Tensor, "s_H s_W"]:
    """
    Sub-sample the (H, W) pixel grid to (s_H, s_W) flat pixel indices, for
    efficient training at a higher image resolution. Pixels inside the
    centered (s_H, s_W) crop are sampled with probability (1 - p), pixels
    outside it with probability p. Masking of the ray directions is deferred
    until just before get_rays() is called.
    """
    # Per-pixel (x, y) coordinates of the full-resolution grid.
    indices_all = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing="xy",
    )

    # Boolean mask of the centered (s_H, s_W) crop.
    mask = torch.zeros(H, W, dtype=torch.bool)
    mask[
        (H - s_H) // 2 : H - math.ceil((H - s_H) / 2),
        (W - s_W) // 2 : W - math.ceil((W - s_W) / 2),
    ] = True

    # Row-major flat index of every pixel: y * W + x.
    flat_ind = indices_all[0] + W * indices_all[1]
    in_ind_1d = flat_ind[mask]
    out_ind_1d = flat_ind[torch.logical_not(mask)]

    # Fraction of samples drawn from outside the crop. 0.5 already favors the
    # inside region, since it covers a smaller area; (s_H * s_W) / (H * W)
    # would be an alternative.
    p = 0.5

    # Draw (1 - p) * s_H * s_W indices from inside the crop and p * s_H * s_W
    # from outside, uniformly without replacement; the two counts must sum to
    # s_H * s_W so the result can be reshaped to (s_H, s_W).
    select_ind = in_ind_1d[
        torch.multinomial(
            torch.ones_like(in_ind_1d) * (1 - p),
            int((1 - p) * (s_H * s_W)),
            replacement=False,
        )
    ]
    select_ind = torch.cat(
        [
            select_ind,
            out_ind_1d[
                torch.multinomial(
                    torch.ones_like(out_ind_1d) * p,
                    int(p * (s_H * s_W)),
                    replacement=False,
                )
            ],
        ],
        dim=0,
    ).to(dtype=torch.long).view(s_H, s_W)

    return select_ind
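
For reference, the indices are used roughly like this before get_rays() (a sketch; `directions` stands for the full-resolution (H, W, 3) direction grid, not the exact variable name in the PR):

# Sub-sample the ray directions before calling get_rays().
select_ind = mask_ray_directions(H, W, s_H, s_W)  # (s_H, s_W), long
directions_sub = directions.view(H * W, 3)[select_ind.view(-1)].view(s_H, s_W, 3)
# directions_sub is then passed on to get_rays() in place of the full grid.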

DSaurus commented 2 months ago

@jadevaibhav Sure, my idea is to reuse these cached images multiple times, and each time you can apply your sub-sampler to update these images. If my understanding is correct, the current mask sub-sampler renders images that are not complete, and diffusion models like Stable Diffusion are not designed to recover such incomplete images. I think this is why the current mask sub-sampler leads to unstable results.

jadevaibhav commented 2 months ago

@DSaurus the sub-sampler is applied to the generated directions, so we only pass the selected directions to NeRF. When calculating the SDS loss, I pass an image at the original resolution with the rendered colors filled in at the selected indices and 0 elsewhere. I also believe that diffusion is unable to recover such an incomplete image. Rather than creating an incomplete image, I am thinking of interpolating from these rendered colors; this way, no gradients are wasted either. What do you think? I would be happy to continue the caching discussion on Discord if you want. Also, should we merge the current version in the meantime?

jadevaibhav commented 2 months ago

Hi @DSaurus, thanks for approving the PR! I don't have write access, so could you please merge?

I looked into the "interpolation", but there is currently no clean way to do it with randomly sampled positions. I looked at the grid_sample() method, but I cannot define a transformation or mapping from the original-resolution coordinate system to the sampled grid coordinates. I am now experimenting with uniform sub-sampling, with a random offset for the top-left grid corner.
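
Roughly what I mean by uniform sub-sampling with a random offset, plus the interpolation back to full resolution (a sketch of the idea, not the exact PR code; upsample_to_full is my own helper name):

import torch
import torch.nn.functional as F

def uniform_subsample_indices(H, W, s_H, s_W):
    # Regular grid with stride (H // s_H, W // s_W) and a random offset for
    # its top-left corner, so different pixels get covered across iterations.
    stride_h, stride_w = H // s_H, W // s_W
    off_h = int(torch.randint(0, stride_h, (1,)))
    off_w = int(torch.randint(0, stride_w, (1,)))
    ys = torch.arange(s_H) * stride_h + off_h
    xs = torch.arange(s_W) * stride_w + off_w
    return (ys[:, None] * W + xs[None, :]).long()  # (s_H, s_W) flat indices

def upsample_to_full(rgb_sub, H, W):
    # rgb_sub: (1, s_H, s_W, 3) rendered at the sub-sampled resolution.
    # Because the samples form a regular grid, the rendering can simply be
    # interpolated back to (H, W) before the SDS loss, instead of leaving
    # the missing pixels at 0.
    rgb = rgb_sub.permute(0, 3, 1, 2)  # (1, 3, s_H, s_W)
    rgb = F.interpolate(rgb, size=(H, W), mode="bilinear", align_corners=False)
    return rgb.permute(0, 2, 3, 1)  # (1, H, W, 3)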

jadevaibhav commented 2 months ago

I finished the new experiment, and it works better than before! The training time is still roughly the same (~33 min)!

Screenshot 2024-09-22 at 8 53 12 PM

https://github.com/user-attachments/assets/37e1d5fa-578c-4420-a3c8-5408d6a74e17

DSaurus commented 1 month ago

@jadevaibhav LGTM! Could you please create a file named eff_dreamfusion.py in the system folder and put your current code into this file?

jadevaibhav commented 1 month ago

Sure!
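
For reference, here is the skeleton I am starting from (assuming the same layout as the existing DreamFusion system in this repo; the registered name "efficient-dreamfusion-system" is just my working choice):

# threestudio/systems/eff_dreamfusion.py
from dataclasses import dataclass

import threestudio
from threestudio.systems.base import BaseLift3DSystem


@threestudio.register("efficient-dreamfusion-system")
class EfficientDreamFusion(BaseLift3DSystem):
    @dataclass
    class Config(BaseLift3DSystem.Config):
        pass

    cfg: Config

    def training_step(self, batch, batch_idx):
        # Sub-sample ray directions, render, scatter back to the full
        # resolution, and compute the SDS loss as discussed above.
        ...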

jadevaibhav commented 1 month ago

Done! Please review @DSaurus

jadevaibhav commented 1 month ago

Thanks! I would like to contribute more. Are there any new papers or implementations we're looking at?

DSaurus commented 1 month ago

@jadevaibhav I think it would be great if you are interested in implementing Wonder3D and its follow-up papers, which can generate 3D objects in seconds.