google-research / multinerf

A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF
Apache License 2.0

Questions on calculating the distance_mean (depth) #103

Open cwchenwang opened 1 year ago

cwchenwang commented 1 year ago

Thanks for the excellent work on Mip-NeRF 360. I am a little confused about the depth loss. The depth is currently computed as: rendering['distance_mean'] = (jnp.clip(jnp.nan_to_num(jnp.exp(expectation(jnp.log(t_mids))), jnp.inf), tdist[..., 0], tdist[..., -1])), which, apart from the NaN handling and clipping, is equal to t_mids[0]^weights[0] * t_mids[1]^weights[1] * ....

However, the depth in the original NeRF seems to be sum(weights * t_mids). Why is there this difference? Also, when computing the depth losses, the depths are converted into 1/(1+rendering['distance_mean']). Why not just use the MSE of the depth maps?

Looking forward to your reply!

jonbarron commented 1 year ago

Ah, looks like we're computing the geometric mean instead of the arithmetic mean there. I think this was done to prevent numerical issues when t is very large, and I remember it being very similar to the arithmetic mean in other circumstances. Probably should have noted that in the paper, sorry.
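
For readers comparing the two means, here is a minimal JAX sketch. It is an illustration only, not the multinerf code. It assumes that expectation(x) is a weighted sum over samples and that the weights on a ray sum to 1. The function names and the two example rays are hypothetical.

```python
import jax.numpy as jnp


def arithmetic_mean_distance(weights, t_mids):
    # Original-NeRF-style expected distance: sum_i w_i * t_i.
    return jnp.sum(weights * t_mids, axis=-1)


def geometric_mean_distance(weights, t_mids):
    # Weighted geometric mean: exp(sum_i w_i * log(t_i)) == prod_i t_i ** w_i.
    # Assumes t_mids > 0 and that the weights sum to 1 along the last axis.
    return jnp.exp(jnp.sum(weights * jnp.log(t_mids), axis=-1))


# Hypothetical weights: most of the mass sits on the sample at t = 2.
weights = jnp.array([0.05, 0.90, 0.04, 0.01])

# Ray A: all samples are nearby, so the two means nearly agree.
t_near = jnp.array([1.0, 2.0, 3.0, 4.0])
arith_near = arithmetic_mean_distance(weights, t_near)
geom_near = geometric_mean_distance(weights, t_near)

# Ray B: the last sample is very far away but carries only 1% of the weight.
t_far = jnp.array([1.0, 2.0, 3.0, 1e6])
arith_far = arithmetic_mean_distance(weights, t_far)
geom_far = geometric_mean_distance(weights, t_far)

# The exp-of-log form matches the explicit product t_i ** w_i.
prod_far = jnp.prod(t_far ** weights)

print("near ray:", arith_near, geom_near)
print("far ray: ", arith_far, geom_far, prod_far)
```

On the near ray both means come out close to 2. On the far ray the arithmetic mean is pulled to roughly 1e4 by the single distant sample, while the geometric mean stays a little above 2. This is one way to see why the two can behave differently when t is very large.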

The depth losses are computed in terms of disparity, which is usually easier to work with than raw depth: it's easier to reason about disparity=0 than depth=infinity.
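
A small JAX sketch of the disparity argument follows. It uses the 1/(1+depth) mapping quoted in the question. The loss functions and the example values are hypothetical and are not taken from the multinerf code.

```python
import jax.numpy as jnp


def depth_to_disparity(depth):
    # Mapping quoted in the question: 1 / (1 + depth).
    # For depth >= 0 the result is bounded, and depth = inf maps to 0.
    return 1.0 / (1.0 + depth)


def disparity_mse(pred_depth, target_depth):
    # Hypothetical loss: mean squared error between disparities.
    diff = depth_to_disparity(pred_depth) - depth_to_disparity(target_depth)
    return jnp.mean(diff ** 2)


def depth_mse(pred_depth, target_depth):
    # Plain mean squared error on raw depths, for comparison.
    return jnp.mean((pred_depth - target_depth) ** 2)


# Hypothetical pixels: one nearby surface and one background pixel
# whose target depth is infinite.
pred = jnp.array([2.0, 1e6])
target = jnp.array([2.5, jnp.inf])

loss_disp = disparity_mse(pred, target)  # stays finite
loss_depth = depth_mse(pred, target)     # becomes inf because of the background pixel

print(loss_disp, loss_depth)
```

In this example the disparity loss remains finite and small, while the raw-depth MSE is infinite as soon as one target lies at infinity.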