jonbarron / robust_loss_pytorch

A pytorch port of google-research/google-research/robust_loss/
Apache License 2.0

Loss drives below zero for self-supervised depth estimation and training fails #16

Open rvarun7777 opened 4 years ago

rvarun7777 commented 4 years ago

"Hi, unfortunately we don't have a code release for the monocular depth estimation experiments of the paper (though that code is in TF anyways so it likely isn't what you're looking for). I believe that there are Pytorch implementations of SFMLearner on Github, and using this loss should be straightforward: just delete the existing multiscale photometric loss and the smoothness term and add in AdaptiveImageLossFunction on the full-res image with: scale_lo=0.01 scale_init=0.01 and default settings for the rest and it should work (you may need to fiddle with the value of wavelet_scale_base)."

Do you suggest any other changes apart from this? I am actually testing this :)

"If you call the loss in general.py, it should never go negative. The loss produced by adaptive.py is actually a negative log-likelihood, so it can go negative depending on the value of the scale parameter."

Could you kindly elaborate on this? Doesn't -log(y) yield positive values? For example, NLL: -ln(0.5) = 0.69. How does the loss go negative, or do we take abs(l1) at the end?

After I add your loss, the final loss drops below 0 and training fails.

I tested your loss in YUV space on the monodepth2 project. These were the initial settings used:

AdaptiveImageLossFunction((3, 640, 192), torch.float32, "cuda:0", color_space='YUV', scale_lo=0.01, scale_init=0.01)

Note: alpha and scale are added to the optimizers; I have skipped that part here. For reference, the final loss of monodepth with L1 is around 0.07 to 0.02 and performs well.

Test 1: Final loss = 0.85 Adaptive Loss + 0.15 SSIM + 0.001 Smoothness

Test 2: I tried a different weighting as well. This setting also fails after a few epochs. Final loss = 0.15 Adaptive Loss + 0.85 SSIM + 0.001 Smoothness

jonbarron commented 4 years ago

Cool, I've gotten many requests for this SFMLearner result in Pytorch but I don't have cycles to take care of it myself, so I'm happy to help.

Here's the part of the main paper that describes important changes: "We keep our loss’s scale c fixed to 0.01, thereby matching the fixed scale assumption of the baseline model and roughly matching the shape of its L1 loss (Eq. 15). To avoid exploding gradients we multiply the loss being minimized by c, thereby bounding gradient magnitudes by residual magnitudes (Eq. 14)."

Section H of the appendix goes through all the changes that were made to the codebase:

So with this, looking at the code you sent me, it looks like you should do:

0.01 * AdaptiveImageLossFunction((3, 640, 192), torch.float32, "cuda:0", color_space='YUV', scale_lo=0.01, scale_init=0.01, wavelet_scale_base=$SOMETHING)

And then just set the final loss to be 100% this adaptive loss (no SSIM or smoothness). I think wavelet_scale_base should be 2, but the interface for that experiment isn't the same as what is here, so try [0.5, 1, 2]. And be sure to delete the secondary branches of the code that evaluate this loss at multiple scales; you should only need to evaluate it at the finest scale.
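To illustrate the "multiply the loss by c" step, here is a minimal sketch, not the paper's code. It assumes the shape parameter is fixed at alpha = 1, where the general loss reduces to sqrt((x/c)^2 + 1) - 1; the adaptive loss learns alpha and adds a log-partition term, and neither is modeled here. The helper names loss_alpha1 and grad_alpha1 are made up for this example. In this special case the unscaled gradient approaches 1/c for large residuals, so scaling the loss by c keeps the gradient magnitude below 1.

```python
import math


def loss_alpha1(x, c):
    """General robust loss at alpha = 1 (illustrative special case)."""
    return math.sqrt((x / c) ** 2 + 1.0) - 1.0


def grad_alpha1(x, c):
    """Analytic derivative of loss_alpha1 with respect to the residual x."""
    return (x / c ** 2) / math.sqrt((x / c) ** 2 + 1.0)


c = 0.01  # fixed scale, as in the quoted paper excerpt
residuals = [0.001, 0.01, 0.1, 1.0]  # hypothetical residual magnitudes

for x in residuals:
    raw = grad_alpha1(x, c)  # gradient of the unscaled loss, grows toward 1/c
    scaled = c * raw         # gradient of c * loss, stays below 1
    print(f"x = {x:<6} d(loss)/dx = {raw:9.4f}   d(c * loss)/dx = {scaled:.4f}")
```

With c = 0.01 the unscaled gradient approaches 100 as the residual grows, which is the kind of magnitude the scaling is meant to tame.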

rvarun7777 commented 4 years ago

Thanks for the tips. Just one small note: why do we need to remove the SSIM and smoothness losses? Aren't structural similarity and smoothing of the depth quite important?

jonbarron commented 4 years ago

I don't remember the TF SFMLearner code having an SSIM loss, so I'm not sure what to make of that. But in my experiments, performance was nearly identical / slightly better without a smoothness loss; it only adds value if the loss being used as the data term is falling short.

rvarun7777 commented 4 years ago

I will try it out. Will share my findings here. I think you missed this one!

"If you call the loss in general.py, it should never go negative. The loss produced by adaptive.py is actually a negative log-likelihood, so it can go negative depending on the value of the scale parameter."

Could you kindly elaborate on this? Doesn't -log(y) yield positive values? For example, NLL: -ln(0.5) = 0.69. How does the loss go negative, or do we take abs(l1) at the end?

jonbarron commented 4 years ago

There's no reason for NLLs to be non-negative. If a likelihood is > 1, the corresponding NLL will be < 0. Basic distributions like Gaussians do this, if you set sigma to a small value. This will only happen in this code when the scale parameter is small, but it should not be any cause for concern.
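A small worked example of this point, using a plain 1-D Gaussian rather than the distribution in adaptive.py (the function name gaussian_nll is made up for this sketch). The -ln(0.5) = 0.69 example treats the likelihood as a probability, which is at most 1. A density has no such bound: a Gaussian with a small sigma peaks well above 1, so its NLL at the peak is negative.

```python
import math


def gaussian_nll(x, mu, sigma):
    """Negative log-likelihood of x under a 1-D Gaussian N(mu, sigma^2)."""
    return (0.5 * ((x - mu) / sigma) ** 2
            + math.log(sigma)
            + 0.5 * math.log(2.0 * math.pi))


# A probability of a discrete outcome is <= 1, so its NLL is >= 0.
nll_discrete = -math.log(0.5)

# Densities can exceed 1. Evaluate both Gaussians at their mean (zero residual).
nll_wide = gaussian_nll(0.0, 0.0, 1.0)     # peak density ~0.40, NLL is positive
nll_narrow = gaussian_nll(0.0, 0.0, 0.01)  # peak density ~39.9, NLL is negative

print(f"-ln(0.5)          = {nll_discrete:.3f}")
print(f"NLL, sigma = 1.0  = {nll_wide:.3f}")
print(f"NLL, sigma = 0.01 = {nll_narrow:.3f}")
```

The sign of the loss therefore carries no meaning on its own; only its decrease during training does.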

rvarun7777 commented 4 years ago

"Be sure to delete the secondary branches of the code that evaluate this loss at multiple scales, you should only need to evaluate it at the finest scale."

In monodepth2, all the lower scales are upsampled to the highest scale and the photometric error is computed on them. From the monodepth2 paper:

"Instead of computing the photometric error on the ambiguous low-resolution images, we first upsample the lower resolution depth maps (from the intermediate layers) to the input image resolution, and then reproject, resample, and compute the error pe at this higher input resolution. This procedure is similar to matching patches, as low-resolution disparity values will be responsible for warping an entire ‘patch’ of pixels in the high resolution image. This effectively constrains the depth maps at each scale to work toward the same objective i.e. reconstructing the high resolution input target image as accurately as possible."

jonbarron commented 4 years ago

Oh sorry, the advice I was giving was for SFMLearner, which is the codebase I did my experiments for the paper with. That advice may not apply to Monodepth2, though I'd still try just imposing a single loss on the finest scale.

yellowYuga commented 3 years ago

"There's no reason for NLLs to be non-negative. If a likelihood is > 1, the corresponding NLL will be < 0. Basic distributions like Gaussians do this, if you set sigma to a small value. This will only happen in this code when the scale parameter is small, but it should not be any cause for concern."

So there is no need to take abs(AdaptiveImageLoss) when the loss goes negative? I use this for HITNet (https://github.com/google-research/google-research/tree/master/hitnet) and set: self.robust = AdaptiveImageLossFunction((960, 320, 1), float_dtype=torch.float32, device=0).

rvarun7777 commented 3 years ago

@yellowYuga I don't know whether this snippet will help you. I used it for depth estimation. For more information you can look it up here: SynDistNet

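import torch
from robust_loss_pytorch import AdaptiveImageLossFunction

# Excerpt from inside a training class: args, target, and predicted come from the surrounding code.
self.adaptive_image_loss = AdaptiveImageLossFunction((3, args.input_width, args.input_height),
                                                     torch.float32, args.device, color_space='YUV',
                                                     scale_lo=0.01, scale_init=0.01)
# Scale the adaptive loss by 0.01, then average over dim 1 (keepdim=True).
l1_loss = (0.01 * self.adaptive_image_loss.lossfun(torch.abs(target - predicted))).mean(1, True)

yellowYuga commented 3 years ago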

In my experiment, I do not take the abs of the residual between predicted and target. That is: loss = self.robust_loss.lossfun(target - predicted).mean(). Is that correct? (BTW, when I tested it on fitting the curve y=x^2, lossfun(target - pred) did better than lossfun(abs(target - pred)).)

jonbarron commented 3 years ago

Do not minimize abs(AdaptiveImageLoss). It's totally normal for the loss to go below zero, and adding an abs() will completely break it. Adding an abs() to the input residual passed to the loss will have no effect, so you can do it if you want, but I don't see why you would (the residual is immediately squared in the loss function).
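Both halves of this can be checked with a minimal sketch. It assumes the scalar form of the general robust loss for alpha not equal to 0 or 2, in which the residual appears only as (x/c)^2; it covers the pointwise loss only, and any transform applied to the residual before the loss is not modeled. The function name robust_loss and the sample values are made up for this example.

```python
import math


def robust_loss(x, alpha, c):
    """General robust loss for alpha not in {0, 2}; x enters only via (x/c)^2."""
    b = abs(alpha - 2.0)
    return (b / alpha) * (((x / c) ** 2 / b + 1.0) ** (alpha / 2.0) - 1.0)


alpha, c = 1.0, 0.01

# abs() on the input residual: the loss is even in x, so nothing changes.
for r in (-0.3, -0.02, 0.0, 0.05):
    same = math.isclose(robust_loss(r, alpha, c), robust_loss(abs(r), alpha, c))
    print(f"residual {r:+.2f}: loss(r) == loss(abs(r)) -> {same}")

# abs() on the output: hypothetical NLL values, best fit (most negative) first.
nll_values = [-3.0, -1.0, 0.5]
best_raw = min(nll_values)                  # picks the best fit
best_abs = min(abs(v) for v in nll_values)  # picks the value closest to zero
print(f"argmin of NLL      -> {best_raw}")
print(f"argmin of abs(NLL) -> {best_abs}")
```

Minimizing abs(NLL) rewards values near zero rather than more negative ones, so it stops preferring the best fit once the NLL crosses zero.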

jonbarron commented 3 years ago

self.robust_loss.lossfun(target - predicted).mean() looks right to me.