onpix / LLNeRF

[ICCV2023] Lighting up NeRF via Unsupervised Decomposition and Enhancement
https://whyy.site/paper/llnerf
MIT License

About smooth loss #16

Closed: XLR-man closed this issue 2 months ago

XLR-man commented 5 months ago

Hi, thanks for your great work!

I am trying to replicate your work using the PyTorch framework. Following the structure of your code, I got some preliminary results: [image]

Clearly something is wrong in the part that produces the enhanced image. I found that my initial understanding of the smooth loss was wrong. In your code, you compute the loss over three types of rays: the raw pixel plus its horizontal and vertical neighbours. So, as I understand it, the ray origins and ray directions in each batch should have shape (batch, 3, 3)?

However, in my PyTorch code I defined the ray origins and directions with shape (batch, 3), so I dropped the middle dimension; that is, I never sampled the three types of rays, which makes the smooth loss invalid.

So how should I fix this? Simply sample the extra neighbouring pixels to get the three types of rays? In that case, before feeding the network, the input would need to be reshaped to two dimensions, (batch * 3, 3), to pass through the linear layers, and then reshaped back to (batch, 3, 3) to compute the smooth loss. Can such a method successfully produce enhanced images?
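For concreteness, this is roughly how I imagine building the three-ray batch (just a sketch; origins, dirs, rows, and cols are hypothetical names for per-pixel ray buffers and sampled pixel coordinates, not names from your code):

import torch

def make_smooth_ray_batch(origins, dirs, rows, cols):
    # origins, dirs: [H, W, 3] per-pixel ray origins / directions
    # rows, cols:    [batch] long tensors of sampled pixel coordinates
    H, W = dirs.shape[:2]
    right = (cols + 1).clamp(max=W - 1)   # horizontal neighbour
    down = (rows + 1).clamp(max=H - 1)    # vertical neighbour

    # stack (pixel, right neighbour, down neighbour) -> [batch, 3, 3]
    o = torch.stack([origins[rows, cols], origins[rows, right], origins[down, cols]], dim=1)
    d = torch.stack([dirs[rows, cols], dirs[rows, right], dirs[down, cols]], dim=1)
    return o, d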

I would appreciate your help!

onpix commented 5 months ago

According to https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear, I think passing a tensor of shape [batchsize, 3, 3] to an MLP is fine.
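For example, a quick sanity check (a minimal sketch, not the LLNeRF code) shows that nn.Linear acts on the last dimension only, so no reshaping is needed:

import torch
import torch.nn as nn

layer = nn.Linear(3, 64)      # applied to the last dimension only
x = torch.randn(8, 3, 3)      # [batch, 3 rays (pixel / right / down), 3 features]
y = layer(x)
print(y.shape)                # torch.Size([8, 3, 64])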

XLR-man commented 5 months ago

Thank you for your reply! I have tried your suggestion; however, I still can't get satisfactory results. Unlike the image above, the enhanced image now shows large artifacts in a grass-green colour.

I'm fairly confident that the training for low-light image reconstruction is fine: after training for 100k iterations with the loss function but without the enhancement losses, the PSNR between the rendered low-light image and the original image is about 40 dB, and the rendered low-light image is as shown above. The enhanced image output by the enhancement network is: [image]

Therefore, I think there is a problem with the enhancement process. The enhancement functions below are adapted from your code to the torch framework:

import torch
import torch.nn as nn


def gray_loss(rgb, ref=None):
    # gray-world colour constancy loss: penalise differences between the colour channels
    assert rgb.shape[-1] == 3
    assert ref is not None

    weight1 = 1
    # down-weight pixels whose reference already has a large channel variance (strongly coloured regions)
    weight2 = torch.var(ref, dim=-1, keepdim=True) + 0.5

    # squared pairwise differences between colour channels
    diffs = (rgb - torch.roll(rgb, shifts=1, dims=-1)) ** 2  # shape: [..., 3]
    return torch.sqrt(diffs.sum(dim=-1, keepdim=True) / 3 / weight1 / weight2).mean()

class Exp_loss_global(nn.Module):
    # exposure control loss: push the mean channel intensity towards mean_val
    def __init__(self, mean_val=0.5):
        super().__init__()
        self.mean_val = mean_val

    def forward(self, x):
        return torch.mean((x.mean(dim=-1) - self.mean_val) ** 2)

def ltv_loss(L_e, L, beta=1.5, alpha=2, eps=1e-4):
    # local total-variation smoothness loss: the enhanced component L_e should be
    # smooth wherever the reference component L is smooth.
    # Both tensors are expected to have shape [batch, 3, ...], where the middle
    # dimension stacks (pixel, right neighbour, down neighbour).
    L = torch.log(L + eps)
    pix_L, right_pix_L, down_pix_L = L[:, 0, ...], L[:, 1, ...], L[:, 2, ...]
    pix_Le, right_pix_Le, down_pix_Le = L_e[:, 0, ...], L_e[:, 1, ...], L_e[:, 2, ...]

    # horizontal and vertical gradients of the reference and of the enhanced component
    dx_L = pix_L - right_pix_L
    dy_L = pix_L - down_pix_L
    dx_Le = pix_Le - right_pix_Le
    dy_Le = pix_Le - down_pix_Le

    # penalise gradients of L_e more strongly where the reference gradient is small
    ltv_x = (beta * dx_Le ** 2) / (dx_L ** alpha + eps)
    ltv_y = (beta * dy_Le ** 2) / (dy_L ** alpha + eps)

    return ((ltv_x + ltv_y) / 2).mean()

And the code to compute the enhancement losses is:

r_norm_coarse = r_coarse / r_coarse.max()
r_norm_fine = r_fine / r_fine.max()

loss_control = helper.Exp_loss_global(mean_val=self.eta)(rgb_coarse_light) \
    + helper.Exp_loss_global(mean_val=self.eta)(rgb_fine_light)  # twilight

loss_cc = helper.gray_loss(rgb_coarse_light, r_norm_coarse) \
    + helper.gray_loss(rgb_fine_light, r_norm_fine)  # Gray world

temp_coarse = torch.broadcast_to(v_coarse, r_coarse.shape)
temp_fine = torch.broadcast_to(v_fine, r_fine.shape)
gamma_ref_coarse, gamma_ref_fine = temp_coarse, temp_fine

loss_smooth = helper.ltv_loss(gamma_coarse, gamma_ref_coarse) + helper.ltv_loss(gamma_fine, gamma_ref_fine)
loss_smooth += helper.ltv_loss(alpha_coarse, v_coarse) + helper.ltv_loss(alpha_fine, v_fine)

loss = loss1 + loss0 + 0.1 * loss_control + 0.1 * loss_cc + 0.1 * loss_smooth

Could you please help me see where my understanding or my modifications went wrong, or what else I should pay attention to? I would appreciate your help! @onpix

XLR-man commented 5 months ago

By the way, I found that when I add the enhancement losses, they become nan after roughly 1k iterations. I also found that some values of the enhanced image are very small and even negative, such as about -2.34e-5. I don't know if that makes a difference.
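To illustrate what I suspect (just a guess on my side): if the enhanced values can go slightly negative, then torch.log(L + eps) in ltv_loss becomes nan as soon as L < -eps, and that nan would then propagate into the total loss. Clamping before the log (second line below, only a sketch) would avoid it:

import torch

L = torch.tensor([-2.34e-5, -2e-4])
print(torch.log(L + 1e-4))               # first entry is finite, second is nan
print(torch.log(L.clamp(min=0) + 1e-4))  # clamping first keeps both finite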

XLR-man commented 5 months ago

Update!

Now the main problem is that the reconstruction based on r and v works well, but once the enhancement loss functions are added as constraints, training fails: not only is the originally reconstructed structure destroyed, but the enhancement result is also very poor, as shown in the following figure:

[image]

May I ask what the possible reason for this is? I hope you can answer.

XLR-man commented 5 months ago

When I remove loss_cc during training, the enhancement result is: [image]

In addition, is your model trained on a single GPU, and does it not support multi-GPU training?

I found that in my modified PyTorch version, the reason may be that some parameters of the model are not involved in gradient backpropagation.

This is because the idea is to first train the reconstruction and then train the enhancement.

When training the reconstruction part, will the parameters of your enhancement network be involved in the gradient calculation and update?

My code gives an error because, when I compute the first loss, some modules of the model have parameters that are not involved in gradient backpropagation. To avoid the error I would have to make these layers participate in the loss1 computation, which means making the output of the final layer participate in loss1. However, the actual requirement is that the final output should not participate in the computation of loss1, right?

So I came up with this idea: since all parameters have to participate in the gradient backpropagation of the loss, why don't I just force the gradients of the layers from the middle layer up to (and including) the final layer to be 0? Then the gradients of those parameters are 0, and updating them is equivalent to not updating them at all. Does this affect anything?
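For reference, what I am considering would look roughly like this (only a sketch with made-up module names, not your code). If the error is the usual DistributedDataParallel complaint about parameters that receive no gradient, the two common ways to handle it are freezing the unused branch or setting find_unused_parameters=True:

import torch
import torch.nn as nn

# A made-up stand-in for the full model: a reconstruction branch plus an
# enhancement head (the names are hypothetical, not from the LLNeRF code).
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.recon = nn.Linear(3, 3)
        self.enhance_head = nn.Linear(3, 3)

    def forward(self, x):
        return self.recon(x)  # reconstruction stage: the enhancement head is unused

model = ToyModel()

# Option 1: freeze the enhancement head during the reconstruction stage, so its
# parameters are excluded from the graph instead of just getting zero gradients.
for p in model.enhance_head.parameters():
    p.requires_grad_(False)

# Option 2 (multi-GPU): DistributedDataParallel raises an error for parameters
# that receive no gradient unless find_unused_parameters=True is set, e.g.
#   model = nn.parallel.DistributedDataParallel(
#       model, device_ids=[local_rank], find_unused_parameters=True)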