albrateanu / LYT-Net

LYT-Net: Lightweight YUV Transformer-based Network for Low-Light Image Enhancement
https://arxiv.org/abs/2401.15204
MIT License
58 stars, 10 forks

histogram_loss has no gradient #16

Closed: russellllaputa closed this issue 5 days ago

russellllaputa commented 3 weeks ago

Dear authors,

Excellent work!

I ran into a problem when I tried to retrain your model. When I train with histogram_loss alone, it raises the following error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. Does this mean that your histogram loss does not contribute to parameter updates? As I understand it, a histogram is discrete, so it cannot be used directly as a loss function.
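To see why a plain histogram gives no useful gradient, here is a minimal pure-Python sketch (no PyTorch). The helper names `hard_histogram` and `l1_histogram_loss`, the 4-bin layout over [0, 1], the L1 distance, and the sample values are illustrative assumptions, not code from LYT-Net. Nudging any predicted value by a tiny amount leaves every bin count unchanged, so the finite-difference slope of the loss is zero almost everywhere, which is consistent with the errors reported in this thread.

```python
import math

def hard_histogram(values, bins=4, lo=0.0, hi=1.0):
    """Hard binning: each value adds exactly 1 to a single bin."""
    counts = [0.0] * bins
    width = (hi - lo) / bins
    for v in values:
        if v < lo or v > hi:
            continue  # ignore out-of-range values
        # A value equal to hi is assigned to the last bin
        k = min(math.floor((v - lo) / width), bins - 1)
        counts[k] += 1.0
    return counts

def l1_histogram_loss(pred, target, bins=4):
    """Hypothetical L1 distance between hard histograms of prediction and target."""
    hp = hard_histogram(pred, bins)
    ht = hard_histogram(target, bins)
    return sum(abs(a - b) for a, b in zip(hp, ht))

pred = [0.10, 0.40, 0.60, 0.90]
target = [0.10, 0.30, 0.35, 0.90]
eps = 1e-4
base = l1_histogram_loss(pred, target)
# Finite-difference slope of the loss w.r.t. each predicted pixel value
grads = []
for i in range(len(pred)):
    bumped = list(pred)
    bumped[i] += eps
    grads.append((l1_histogram_loss(bumped, target) - base) / eps)
print(base, grads)  # 2.0 [0.0, 0.0, 0.0, 0.0]
```

The loss is nonzero, yet no small change to any predicted value moves it; only a change large enough to push a value across a bin edge does, and that produces a jump rather than a slope.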

Looking forward to your reply. Many thanks.

albrateanu commented 2 weeks ago

Hi, thanks for writing.

You are correct that the histogram itself is discrete. However, the histogram_loss is a difference between the prediction and ground-truth histograms, which can be differentiated.
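A short chain-rule sketch of where the gradient gets lost, assuming hard bins with edges b_k and an L1 distance between histograms (both illustrative assumptions): the outer difference is differentiable, but each bin count h_k is a step function of the pixel values, so its derivative is zero almost everywhere and every term of the sum vanishes.

```latex
% Hard count for bin k with edges [b_k, b_{k+1}); \mathbb{1}[\cdot] is the indicator function
h_k(x) = \sum_{i} \mathbb{1}\!\left[ b_k \le x_i < b_{k+1} \right]

% L1 histogram loss between prediction x and ground truth y
\mathcal{L}(x) = \sum_{k} \left| h_k(x) - h_k(y) \right|

% Chain rule (using a subgradient of |\cdot| at 0); h_k(y) does not depend on x
\frac{\partial \mathcal{L}}{\partial x_i}
  = \sum_{k} \operatorname{sign}\!\big( h_k(x) - h_k(y) \big)
    \frac{\partial h_k(x)}{\partial x_i}
  = 0
  \quad \text{since } \frac{\partial h_k(x)}{\partial x_i} = 0
  \text{ whenever } x_i \text{ is not on a bin edge}
```

The same conclusion holds for any distance that is differentiable in h, because the zero factor comes from the counts themselves.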

However, I will check whether the issue is related to how the loss is computed.

I take it this is for the PyTorch implementation?

liuxiaoya396 commented 2 weeks ago

Hi, I ran into the same problem in the PyTorch implementation: the derivative for 'histc' is not implemented. Is there an easy way to implement a histogram function with derivatives?
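One common approach is soft binning: replace each hard indicator with a smooth kernel so that every pixel value contributes a little to every bin. A sketch with a Gaussian kernel, where the bin centers c_k, the bandwidth σ, and the 1/N normalization are illustrative assumptions:

```latex
% Soft count for bin k: every value x_i contributes a Gaussian weight centered on c_k
\tilde{h}_k(x) = \frac{1}{N} \sum_{i=1}^{N} \exp\!\left( -\frac{(x_i - c_k)^2}{2\sigma^2} \right)

% Its derivative is nonzero whenever x_i \neq c_k, so gradients reach each pixel value
\frac{\partial \tilde{h}_k(x)}{\partial x_i}
  = -\frac{1}{N} \, \frac{x_i - c_k}{\sigma^2}
    \exp\!\left( -\frac{(x_i - c_k)^2}{2\sigma^2} \right)
```

A smaller σ makes the soft counts closer to hard counts, at the cost of gradients that are large near the centers and nearly zero elsewhere.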

Thank you for your time.

russellllaputa commented 2 weeks ago

Thanks for your reply. Yes, I meant the PyTorch version, but I also checked the function you used in the TensorFlow version, and it also seems to have no derivatives.

albrateanu commented 5 days ago

Hello. The histogram loss has been updated to use Gaussian kernels that smooth each value's contribution across bins, essentially implementing soft binning instead of hard binning so that the loss remains differentiable.
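A minimal pure-Python sketch of Gaussian soft binning, assuming kernels centered on equal-width bins over [0, 1], a bandwidth `sigma` of 0.1, normalization to a distribution over bins, and an L1 distance. These choices and the helper names `soft_histogram` and `soft_histogram_loss` are illustrative assumptions, not the repository's actual code. Unlike hard binning, nudging a predicted value now changes the loss smoothly, so the finite-difference slopes are nonzero.

```python
import math

def soft_histogram(values, bins=4, lo=0.0, hi=1.0, sigma=0.1):
    """Gaussian soft binning: each value adds a smooth weight to every bin."""
    width = (hi - lo) / bins
    centers = [lo + (k + 0.5) * width for k in range(bins)]
    hist = [
        sum(math.exp(-((v - c) ** 2) / (2 * sigma ** 2)) for v in values)
        for c in centers
    ]
    total = sum(hist)
    # Normalize so the soft counts form a distribution over the bins
    return [h / total for h in hist]

def soft_histogram_loss(pred, target, **kwargs):
    """Hypothetical L1 distance between soft histograms of prediction and target."""
    hp = soft_histogram(pred, **kwargs)
    ht = soft_histogram(target, **kwargs)
    return sum(abs(a - b) for a, b in zip(hp, ht))

pred = [0.10, 0.40, 0.60, 0.90]
target = [0.10, 0.30, 0.35, 0.90]
eps = 1e-5
base = soft_histogram_loss(pred, target)
# Finite-difference slope of the loss w.r.t. each predicted pixel value
grads = []
for i in range(len(pred)):
    bumped = list(pred)
    bumped[i] += eps
    grads.append((soft_histogram_loss(bumped, target) - base) / eps)
print(base, grads)
```

In PyTorch, writing the same expression with tensor operations (for example `torch.exp` over a broadcast difference between pixel values and bin centers) should let autograd compute these gradients directly, since every step is differentiable.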

Thanks for your comment.