albrateanu / LYT-Net

LYT-Net: Lightweight YUV Transformer-based Network for Low-Light Image Enhancement
https://arxiv.org/abs/2401.15204
MIT License
58 stars, 10 forks

PyTorch version training issues #14

Status: Open. longlong161 opened this issue 1 month ago.

longlong161 commented 1 month ago

After training the PyTorch version for a while, the metrics stop changing. I trained it twice, and the problem started at epoch 454. [2 images]

By the way, could you also add the SSIM evaluation metric to the PyTorch version? Thank you.

albrateanu commented 1 month ago

Hello. I will add SSIM evaluation today. As for the training issues, I still need to fix the model's forward pass a bit. Moreover, training in PyTorch apparently behaves differently from TensorFlow, so I will keep checking how to bring it as close as possible to the initial TensorFlow implementation.
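
For readers following the SSIM request: SSIM compares two images through a luminance term (means) and a contrast/structure term (variances and covariance). The sketch below is a minimal, hypothetical illustration, not the repository's code. It assumes grayscale pixels in [0, 1] stored as flat Python lists, and it treats the whole image as one window, whereas standard implementations (e.g. scikit-image's `structural_similarity`) average over local Gaussian windows, so its values will not match library output.

```python
def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM over two equal-length grayscale pixel lists.

    Simplification (assumption): the whole image is one window instead of
    the usual 11x11 Gaussian windows (sigma=1.5) averaged over the image.
    """
    if len(x) != len(y) or not x:
        raise ValueError("x and y must be non-empty and the same length")
    n = len(x)
    c1 = (k1 * data_range) ** 2  # stabilizes the luminance term
    c2 = (k2 * data_range) ** 2  # stabilizes the contrast/structure term
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) * (a - mu_x) for a in x) / n
    var_y = sum((b - mu_y) * (b - mu_y) for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    luminance = (2 * mu_x * mu_y + c1) / (mu_x * mu_x + mu_y * mu_y + c1)
    structure = (2 * cov_xy + c2) / (var_x + var_y + c2)
    return luminance * structure


ref = [0.1, 0.4, 0.4, 0.9]
noisy = [0.15, 0.35, 0.5, 0.8]
print(global_ssim(ref, ref))    # 1.0 for identical images
print(global_ssim(ref, noisy))  # about 0.955
```

Higher is better, with 1.0 only for identical inputs; here both lists share the same mean, so the drop comes entirely from the structure term.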

longlong161 commented 1 month ago

Okay, thank you for your reply. Will you also be adding the testing-related files later? Looking forward to the complete PyTorch version.

stuhao251 commented 1 month ago

It feels like training had already saturated by that epoch, hahaha.

albrateanu commented 1 month ago

[image] @longlong161 I don't know exactly what the training issue is on your side; I didn't run into it. It looks like a gradient problem. Please retry with the latest code and let me know if it persists.

@stuhao251 That might be the case, unfortunately. For some reason, I still get better metrics when training with TF. However, the baseline PyTorch implementation of the model should be exactly the same as the TF version, since both the TF and Torch versions have 44923 parameters. So please feel free to build on it if you think there's room for improvement.
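
The parameter comparison above is usually done with `sum(p.numel() for p in model.parameters())` in PyTorch and `model.count_params()` in Keras. The framework-free sketch below shows the per-layer arithmetic behind such totals for a hypothetical toy stack; the helper names and layers are illustrative and are not LYT-Net's architecture.

```python
def conv2d_params(in_ch, out_ch, kernel, groups=1, bias=True):
    # Each output channel sees in_ch // groups input channels through a kernel x kernel window.
    weights = out_ch * (in_ch // groups) * kernel * kernel
    return weights + (out_ch if bias else 0)


def linear_params(in_features, out_features, bias=True):
    # Dense layer: full weight matrix plus one bias per output feature.
    return in_features * out_features + (out_features if bias else 0)


# Hypothetical toy stack (NOT the LYT-Net layers).
layers = [
    conv2d_params(3, 16, 3),              # 16*3*9 + 16 = 448
    conv2d_params(16, 16, 3, groups=16),  # depthwise: 16*1*9 + 16 = 160
    conv2d_params(16, 32, 1),             # pointwise: 32*16 + 32 = 544
    linear_params(32, 8),                 # 32*8 + 8 = 264
]
total = sum(layers)
print(total)  # 1416
```

Equal totals are a useful sanity check, but they do not by themselves confirm identical layer order, padding, or weight initialization between the two frameworks.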

longlong161 commented 1 month ago

@albrateanu I used the latest code you uploaded and ran into a new error, which appears on both of my machines. [2 images]

GZY2000 commented 1 month ago

Hello, has the PyTorch version been released yet? Why are my results so much worse than the TensorFlow results reported in the paper?

albrateanu commented 1 month ago

@longlong161 I'm not sure how I can help with that. Please make sure your environment is set up correctly and runs PyTorch 2. Perhaps just try retraining a few more times, or set either the Smooth L1 loss weight or the MS-SSIM loss weight to 0. It looks like a gradient explosion problem, and since I never ran into it, it's not clear what causes it.

@GZY2000 As I keep saying, I'm not sure. I am able to get around 0.8-1 dB more on TensorFlow than on PyTorch, but the reason is a bit beyond my knowledge.
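
Since the symptom looks like gradient explosion, one common mitigation not proposed in this thread is clipping gradients by their global norm. In PyTorch this is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`, called between `loss.backward()` and `optimizer.step()`. The sketch below reproduces that scaling rule in plain Python on hypothetical flattened gradient lists so the arithmetic is visible; the `max_norm` value is an assumption to tune, not a LYT-Net setting. Unlike PyTorch's default, this sketch raises on a non-finite norm (PyTorch does so only with `error_if_nonfinite=True`).

```python
import math


def global_grad_norm(grads):
    """L2 norm of all gradient entries, as if every parameter's gradient were concatenated."""
    return math.sqrt(sum(g * g for tensor in grads for g in tensor))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients together so their global L2 norm is at most about max_norm.

    Returns (clipped_grads, norm_before_clipping). Raises if the norm is NaN or inf,
    so a diverging run fails loudly instead of continuing.
    """
    norm = global_grad_norm(grads)
    if not math.isfinite(norm):
        raise FloatingPointError(f"non-finite gradient norm: {norm}")
    scale = max_norm / (norm + eps)
    if scale >= 1.0:
        return grads, norm  # already within the limit, leave untouched
    return [[g * scale for g in tensor] for tensor in grads], norm


# Two hypothetical parameter tensors, flattened to lists.
grads = [[3.0, 4.0], [0.0]]
clipped, before = clip_by_global_norm(grads, max_norm=1.0)
print(before)                     # 5.0
print(global_grad_norm(clipped))  # just under 1.0
```

All gradients are scaled by the same factor, so the update direction is preserved and only its length is capped.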

albrateanu commented 1 month ago

@GZY2000 I have just posted the PyTorch weights. On LOLv2 Real, the PyTorch implementation scores 0.61 dB more than the TensorFlow one.

GZY2000 commented 1 month ago

That's true, but on the Synthetic split of this dataset it is more than 2 dB off, and that is the best result I was able to get after a lot of training!
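
The dB gaps discussed in this thread are presumably PSNR differences. Since PSNR = 10 * log10(MAX^2 / MSE), a gap of d dB means the lower-scoring model's MSE is 10^(d/10) times larger. A minimal sketch, assuming pixel values in [0, 1] stored as flat lists; `psnr` and `mse_ratio_for_gap` are hypothetical helper names.

```python
import math


def psnr(reference, prediction, max_val=1.0):
    """PSNR in dB between two equal-length pixel lists with values in [0, max_val]."""
    if len(reference) != len(prediction) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((r - p) * (r - p) for r, p in zip(reference, prediction)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_val * max_val / mse)


def mse_ratio_for_gap(db_gap):
    """How many times larger the MSE is for a model scoring db_gap dB lower PSNR."""
    return 10 ** (db_gap / 10)


print(psnr([0.0] * 4, [0.1] * 4))  # about 20.0
for gap in (0.61, 1.0, 2.0):
    print(gap, round(mse_ratio_for_gap(gap), 3))  # 1.151, 1.259, 1.585
```

So a 0.61 dB gap corresponds to roughly 15% more squared error, while a 2 dB gap corresponds to roughly 58% more.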