NJU-3DV / Relightable3DGaussian

[ECCV2024] Relightable 3D Gaussian: Real-time Point Cloud Relighting with BRDF Decomposition and Ray Tracing
https://nju-3dv.github.io/projects/Relightable3DGaussian/
376 stars, 24 forks

A question about visibility loss #20

Open yuyuyu223 opened 3 months ago

yuyuyu223 commented 3 months ago

Why isn't `rand_rays_d` normalized before computing the visibility loss? I noticed that normalized directions are used when evaluating the spherical harmonics of the other attributes.
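Some context on why the missing normalization can matter. In 3D Gaussian Splatting-derived code, `eval_sh` writes each real SH basis function as a polynomial in the raw x, y, z components of the direction. Each degree-l term is homogeneous, so scaling the direction by r scales that band by r^l. Directions drawn with `torch.randn_like` also have lengths that vary from sample to sample, following a chi distribution with 3 degrees of freedom. The plain-Python sketch below reproduces that Cartesian form. Its constants and term layout are assumed to match this repository's `eval_sh` and have not been checked against its source. `sh_basis`, `eval_sh_deg3` and `normalize` are illustrative helpers, not repository functions.

```python
import math

# Real SH constants in the Cartesian form used by 3DGS-style eval_sh
# (assumed to match this repository's implementation).
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005,
      -1.0925484305920792, 0.5462742152960396]
C3 = [-0.5900435899266435, 2.890611442640554, -0.4570457994644658,
      0.3731763325901154, -0.4570457994644658, 1.445305721320277,
      -0.5900435899266435]


def sh_basis(d):
    """Evaluate the 16 real SH basis functions (degrees 0-3) at vector d,
    as polynomials in the raw x, y, z components (no normalization)."""
    x, y, z = d
    xx, yy, zz = x * x, y * y, z * z
    return [
        C0,                                                    # degree 0
        -C1 * y, C1 * z, -C1 * x,                              # degree 1
        C2[0] * x * y, C2[1] * y * z, C2[2] * (2 * zz - xx - yy),
        C2[3] * x * z, C2[4] * (xx - yy),                      # degree 2
        C3[0] * y * (3 * xx - yy), C3[1] * x * y * z,
        C3[2] * y * (4 * zz - xx - yy), C3[3] * z * (2 * zz - 3 * xx - 3 * yy),
        C3[4] * x * (4 * zz - xx - yy), C3[5] * z * (xx - yy),
        C3[6] * x * (xx - 3 * yy),                             # degree 3
    ]


def eval_sh_deg3(coeffs, d):
    """Weighted sum of the 16 SH coefficients with the basis evaluated at d."""
    return sum(c * b for c, b in zip(coeffs, sh_basis(d)))


def normalize(d):
    """Scale d to unit length."""
    n = math.sqrt(sum(v * v for v in d))
    return [v / n for v in d]


if __name__ == "__main__":
    coeffs = [0.1 * (i + 1) for i in range(16)]  # arbitrary test coefficients
    d_unit = normalize([0.3, -0.5, 0.8])
    d_long = [2.0 * v for v in d_unit]           # same direction, length 2
    print(eval_sh_deg3(coeffs, d_unit))
    print(eval_sh_deg3(coeffs, d_long))          # differs: band l is scaled by 2**l
```

Two vectors that point the same way but differ in length therefore give different SH values unless the direction is normalized first.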

```python
if opt.lambda_visibility > 0:
    num = 10000
    means3D = pc.get_xyz
    visibility = pc.get_visibility
    normal = pc.get_normal
    opacity = pc.get_opacity

    # Pick up to `num` random Gaussians; their centers serve as ray origins
    rand_idx = torch.randperm(means3D.shape[0])[:num]
    rand_visibility_shs_view = visibility.transpose(1, 2).view(-1, 1, 4 ** 2)[rand_idx]
    rand_rays_o = means3D[rand_idx]
    # Random directions from a standard normal (not normalized to unit length)
    rand_rays_d = torch.randn_like(rand_rays_o)
    cov_inv = pc.get_inverse_covariance()
    rand_normal = normal[rand_idx]
    # Flip directions that point against the normal into its hemisphere
    mask = (rand_rays_d * rand_normal).sum(-1) < 0
    rand_rays_d[mask] *= -1
    # Visibility predicted by the degree-3 SH, offset by 0.5 and clamped to [0, 1]
    sample_sh2vis = eval_sh(3, rand_visibility_shs_view, rand_rays_d)
    sample_vis = torch.clamp(sample_sh2vis + 0.5, 0.0, 1.0)
    # Target visibility obtained by tracing the same rays through the Gaussians
    raytracer = RayTracer(means3D, pc.get_scaling, pc.get_rotation)
    trace_results = raytracer.trace_visibility(
        rand_rays_o, rand_rays_d, means3D,
        cov_inv, opacity, normal)

    rand_ray_visibility = trace_results["visibility"]
    # L1 loss between traced and SH-predicted visibility
    loss_visibility = F.l1_loss(rand_ray_visibility, sample_vis)
    tb_dict["loss_visibility"] = loss_visibility.item()
    loss = loss + opt.lambda_visibility * loss_visibility
```
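For comparison, here is a plain-Python sketch of sampling with normalization. It draws Gaussian samples, divides each by its length (which gives directions uniform on the sphere), and then flips them into the normal's hemisphere. Flipping preserves length, so normalizing before or after the flip gives the same result. In the snippet above, the analogous change would be something like `rand_rays_d = torch.nn.functional.normalize(torch.randn_like(rand_rays_o), dim=-1)`. Whether `trace_visibility` itself requires unit directions depends on its implementation, which isn't shown here. `sample_hemisphere_dirs` is a hypothetical helper, not part of the repository.

```python
import math
import random


def sample_hemisphere_dirs(normal, num, rng=random):
    """Draw `num` unit directions uniformly over the hemisphere around `normal`.

    Same idea as the snippet (Gaussian samples flipped into the normal's
    hemisphere), but every sample is normalized to length 1.
    """
    dirs = []
    for _ in range(num):
        d = [rng.gauss(0.0, 1.0) for _ in range(3)]
        n = math.sqrt(sum(v * v for v in d))
        d = [v / n for v in d]               # isotropic Gaussian / norm -> uniform on sphere
        if sum(a * b for a, b in zip(d, normal)) < 0:
            d = [-v for v in d]              # flip into the normal's hemisphere
        dirs.append(d)
    return dirs


if __name__ == "__main__":
    rng = random.Random(0)
    for d in sample_hemisphere_dirs([0.0, 0.0, 1.0], 5, rng):
        print([round(v, 3) for v in d])      # every sample has z >= 0 and length 1
```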