threestudio-project / threestudio

A unified framework for 3D content generation.
Apache License 2.0

Simplify the formulation of SDS loss #266

Closed: ashawkey closed this issue 1 year ago

ashawkey commented 1 year ago

We can use an even simpler reparameterization to implement the SDS loss:

```python
import torch.nn.functional as F

# current:
target = (latents - grad).detach()
# d(loss_sds)/d(latents) = (latents - target) / batch_size = grad / batch_size
loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

# new:
# d(loss_sds)/d(latents) = grad  (no 1 / batch_size factor)
loss_sds = (latents * grad.detach()).sum()
```

This was originally proposed in https://github.com/ashawkey/stable-dreamfusion/issues/335 (credit to @Xallt). Maybe we can implement it here too?
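A quick way to check what each formulation actually backpropagates is to let autograd compute both gradients and compare them against `grad`. This is a minimal sketch assuming PyTorch; `latents` and `grad` are random stand-in tensors with an arbitrary shape, not real renderer or guidance outputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 4
# Random stand-ins (hypothetical shapes); real latents come from rendering,
# grad from the diffusion guidance.
latents = torch.randn(batch_size, 4, 8, 8, requires_grad=True)
grad = torch.randn(batch_size, 4, 8, 8)

# Current MSE formulation.
target = (latents - grad).detach()
loss_mse = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
(g_mse,) = torch.autograd.grad(loss_mse, latents)

# Proposed dot-product formulation.
loss_dot = (latents * grad.detach()).sum()
(g_dot,) = torch.autograd.grad(loss_dot, latents)

# Same direction, but the MSE gradient carries an extra 1 / batch_size factor.
# atol absorbs the rounding from computing latents - (latents - grad).
print(torch.allclose(g_dot, grad))                           # True
print(torch.allclose(g_mse, grad / batch_size, atol=1e-6))  # True
```

With `batch_size = 1` the two gradients coincide; for larger batches the dot-product form is `batch_size` times larger.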

bennyguo commented 1 year ago

Looks good! Could you verify that the magnitude of the gradient remains the same (especially for data.batch_size > 1)?
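One way to keep the gradient magnitude identical for any batch size would be to divide the dot-product form by `batch_size` as well. The sketch below assumes PyTorch; `sds_loss_mse` and `sds_loss_dot` are hypothetical helper names, and the normalized dot-product variant is an illustration, not something adopted in this thread:

```python
import torch
import torch.nn.functional as F


def sds_loss_mse(latents, grad, batch_size):
    # MSE reparameterization from above: d(loss)/d(latents) = grad / batch_size
    target = (latents - grad).detach()
    return 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size


def sds_loss_dot(latents, grad, batch_size):
    # Hypothetical batch-normalized dot-product form: also grad / batch_size
    return (latents * grad.detach()).sum() / batch_size


torch.manual_seed(0)
for batch_size in (1, 2, 8):
    latents = torch.randn(batch_size, 4, 8, 8, requires_grad=True)
    grad = torch.randn(batch_size, 4, 8, 8)
    (g_mse,) = torch.autograd.grad(sds_loss_mse(latents, grad, batch_size), latents)
    (g_dot,) = torch.autograd.grad(sds_loss_dot(latents, grad, batch_size), latents)
    # Equal up to float rounding; prints True for each batch size.
    print(batch_size, torch.allclose(g_mse, g_dot, atol=1e-6))
```

This only matches the gradients; the loss values of the two forms still differ.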

ashawkey commented 1 year ago

Oh, I didn't realize that. The loss value of this formulation is not as meaningful as that of the original one: it involves latents, whereas the original value depends only on the magnitude of grad (it equals 0.5 * ||grad||^2 / batch_size). If we want the value of the loss to be informative, I guess we should keep using the original one.
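The difference in loss values can be seen by holding `grad` fixed and shifting the latents. This is a minimal sketch assuming PyTorch, with random stand-in tensors and a hypothetical helper `loss_values`; the MSE value should stay at 0.5 * ||grad||^2 / batch_size while the dot-product value moves with the latents:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 2
grad = torch.randn(batch_size, 4, 8, 8)  # fixed stand-in SDS gradient


def loss_values(latents):
    # Returns (MSE-form value, dot-product-form value) for the fixed grad.
    target = (latents - grad).detach()
    mse = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
    dot = (latents * grad).sum()
    return mse.item(), dot.item()


expected = 0.5 * grad.pow(2).sum().item() / batch_size
for shift in (0.0, 10.0):
    latents = torch.randn(batch_size, 4, 8, 8) + shift
    mse, dot = loss_values(latents)
    # mse stays ~expected for any latents; dot depends on the latents themselves.
    print(f"shift={shift}: mse={mse:.4f} (expected {expected:.4f}), dot={dot:.4f}")
```

Since the gradients have the same direction either way, the choice here is mainly about which number is worth logging.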

Xallt commented 1 year ago

Yeah, this MSE formulation is also cool, and it gives an even more meaningful loss value. I'd actually prefer it to the one I suggested.

voletiv commented 1 year ago

I used the latents * grad.detach() formulation in my own experiments (~8 months ago), and I find the threestudio loss to be more effective/faster in training. But I haven't done extensive experiments on it.

Xallt commented 1 year ago

@voletiv Faster in training? Why do you think that happens? Is the kernel for the F.mse_loss gradient computation faster than the one for the multiplication?

voletiv commented 1 year ago

@Xallt By "faster" I don't mean per unit of compute; I mean the number of iterations from initialization to a good final rendering. Optimization seems more effective (in my experiments). Multiple factors contribute to this, possibly unrelated to the loss as such: the combination of an implicit volume and the original loss works very well, but the combination of an SDF and the original loss is not as effective (in my experiments for 2D->3D).

bennyguo commented 1 year ago

I think @Xallt makes a very good point. Although latents * grad.detach() is easier to understand, the MSE formulation gives a loss value that reflects the gradient magnitude. So I think we'll stick with the MSE formulation.