Lakonik / SSDNeRF

[ICCV 2023] Single-Stage Diffusion NeRF
https://lakonik.github.io/ssdnerf/
MIT License

Question about triplane prediction & optimization #29

Open rfeinman opened 9 months ago

rfeinman commented 9 months ago

Thank you so much for an awesome code library!

I am trying to train a neural network to predict triplane codes from a single reference view of an object. I am using your triplane-nerf library for the rendering, and it works fairly well, but I am seeing odd pixelation and artifacts even after training to convergence. Below is a brief code sketch of the optimization procedure I follow during training; the parameters of decoder and predictor_net are the ones being optimized. Am I doing anything wrong here? I've included a visualization of the predicted (rendered) image vs. the target image at the bottom of this message.

I noticed that the output density_bitfield from nerf.get_density does not carry gradients. Don't we need gradients to flow through the density MLP for proper training? Is there a way to make this step differentiable?

from lib.models.autodecoders.base_nerf import BaseNeRF
from lib.models.decoders.triplane_decoder import TriPlaneDecoder
from lib.core.utils.nerf_utils import get_cam_rays

decoder = TriPlaneDecoder(
    base_layers=[3 * 6, 64], 
    density_layers=[64, 1],
    color_layers=[64, 3],
    dir_layers=[16, 64],
)

nerf = BaseNeRF(code_size=(3, 6, 64, 64), grid_size=64)

def render(code, density_bitfield, h, w, intrinsics, poses):
    # generate per-pixel camera rays for the given poses and intrinsics
    rays_o, rays_d = get_cam_rays(poses, intrinsics, h=h, w=w)

    batch_size, height, width, channels = rays_o.shape
    rays_o = rays_o.view(batch_size, height * width, channels)
    rays_d = rays_d.view(batch_size, height * width, channels)        

    outputs = decoder(rays_o, rays_d, code, density_bitfield, nerf.grid_size)

    # composite the rendered color onto the background using the accumulated weights
    image = outputs['image'] + nerf.bg_color * (1 - outputs['weights_sum'].unsqueeze(-1))

    return image.reshape(batch_size, h, w, 3)

for _ in range(iterations):
    # predict a triplane code from the reference view (reference_img is 128 x 128)
    triplane_code = predictor_net(reference_img, reference_intrinsics, reference_poses)

    # build the occupancy bitfield used to prune empty space during ray marching
    # (this step is non-differentiable; no gradients flow through it)
    _, density_bitfield = nerf.get_density(
        decoder, triplane_code, cfg=dict(density_thresh=0.1, density_step=16))

    pred_img = render(
        triplane_code, density_bitfield, h=128, w=128, 
        intrinsics=target_intrinsics, poses=target_poses)

    loss = (pred_img - target_img).pow(2).mean()
    loss.backward()
    #optimizer.step() ... etc

prediction vs. target:

[image: predicted rendering (left) vs. target image (right)]

Lakonik commented 9 months ago

Hi! The density grid is used for occupancy-based pruning, which is excluded from the gradient graph by design. Since you are using triplanes with a resolution of 64, it is possible that there is not enough capacity to capture the full details of the target. However, this problem can be mitigated by using an LPIPS loss, which will be a new feature of this codebase in an upcoming release.
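
For reference, here is a rough sketch of how a perceptual term could be added to the training loop above. It assumes the pip-installable lpips package; the net='vgg' backbone and the 0.1 weight are illustrative choices, not part of this codebase:

import lpips

# create once, outside the training loop; LPIPS expects NCHW images scaled to [-1, 1]
lpips_fn = lpips.LPIPS(net='vgg').cuda()

def to_lpips_format(img):
    # (batch, h, w, 3) in [0, 1] -> (batch, 3, h, w) in [-1, 1]
    return img.permute(0, 3, 1, 2) * 2 - 1

# inside the training loop, after rendering pred_img:
mse_loss = (pred_img - target_img).pow(2).mean()
perc_loss = lpips_fn(to_lpips_format(pred_img), to_lpips_format(target_img)).mean()
loss = mse_loss + 0.1 * perc_loss  # 0.1 is an illustrative weight
loss.backward()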

rfeinman commented 8 months ago

Hi @Lakonik - thanks for the helpful reply to my question! It seems like occupancy-based pruning is designed primarily for the discrete NeRF setting, where you have a finite set of scenes and can maintain a per-scene grid state (the density_bitfield) that is updated at some interval.

I'm wondering: what is your suggested approach for ray marching in a setting where the number of NeRFs is effectively unbounded, for example when predicting a triplane NeRF from an image as in LRM? Should the density_bitfield be computed from scratch for each prediction, or should I just use some dummy value for the bitfield?

Lakonik commented 8 months ago

The density_bitfield can either be computed from scratch (by running multiple update steps at once, which costs some time) or simply filled with 255 (no pruning). If you render multiple views at once, computing the density_bitfield first may be faster; otherwise just use the filled dummy.
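
As a minimal sketch of the no-pruning option, building on the snippet above: the (num_scenes, grid_size ** 3 // 8) uint8 shape assumed here follows torch-ngp-style occupancy bitfields and should be checked against what nerf.get_density actually returns in this repo.

import torch

# dummy bitfield with every bit set: all grid cells treated as occupied, so nothing is pruned
# shape assumption: grid_size**3 occupancy bits per scene, packed into grid_size**3 // 8 bytes
num_scenes = triplane_code.size(0)
density_bitfield = torch.full(
    (num_scenes, nerf.grid_size ** 3 // 8), 255,
    dtype=torch.uint8, device=triplane_code.device)

pred_img = render(
    triplane_code, density_bitfield, h=128, w=128,
    intrinsics=target_intrinsics, poses=target_poses)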