DhairyaLGandhi / UNet.jl

Generic UNet implementation written in pure Julia, based on Flux.jl
MIT License

Add pixelwise loss weight? #13

Open tlnagy opened 4 years ago

tlnagy commented 4 years ago

In the original implementation, they used a weighted loss function to up-weight border pixels so that the network learns those preferentially (see Fig. 3D below).

[Fig. 3D from the original U-Net paper: pixel-wise loss weight map]

Do you have any suggestions for how to implement this in UNet.jl? I'm still really new to Flux, so sorry if this is obvious. My guess would be to implement it in `loss()`:

https://github.com/DhairyaLGandhi/UNet.jl/blob/954c89e8e2a9dd4cfad8c265a9cda03eef85e6f5/src/utils.jl#L49-L52

EDIT: Here's an implementation of the pixel-wise weights for Keras: https://jaidevd.github.io/posts/weighted-loss-functions-for-instance-segmentation/
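For what it's worth, here's a minimal sketch of what a pixel-wise weighted loss could look like in plain Julia. This is not UNet.jl's API; `weighted_bce`, `ŷ`, `y`, and `w` are illustrative names, and it assumes predictions are probabilities in (0, 1) and the weight map `w` has the same spatial size as the target:

```julia
# Hypothetical pixel-wise weighted binary cross-entropy (a sketch, not
# part of UNet.jl). ŷ: predicted probabilities, y: binary targets,
# w: per-pixel weight map of the same size.
function weighted_bce(ŷ, y, w; ϵ = eps(eltype(ŷ)))
    # elementwise binary cross-entropy at each pixel
    ce = @. -(y * log(ŷ + ϵ) + (1 - y) * log(1 - ŷ + ϵ))
    # weight each pixel's contribution, then average
    return sum(w .* ce) / length(ce)
end
```

With `w .= 1` this reduces to the usual mean BCE, so the existing unweighted loss is the special case of a uniform weight map.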

DhairyaLGandhi commented 3 years ago

Could we try with a translation first? I'm not familiar with their implementation, but it seems to be a combination of masking and weighting. Shouldn't be too difficult at all.

tlnagy commented 3 years ago

What do you mean by a translation first?

DhairyaLGandhi commented 3 years ago

I meant translating the loss from Keras; sorry, I should've been clearer.
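As a starting point for that translation: the U-Net paper defines the weight map as w(x) = w_c(x) + w₀ · exp(−(d₁(x) + d₂(x))² / (2σ²)), where d₁ and d₂ are the distances to the nearest and second-nearest cell border and w_c balances class frequencies. Given precomputed distance maps (e.g. from a distance transform), the formula itself is a one-liner; `unet_weight_map` is an illustrative name, and w₀ = 10, σ = 5 are the paper's reported values:

```julia
# Sketch of the U-Net paper's weight-map formula (not part of UNet.jl):
#   w(x) = w_c(x) + w₀ * exp(-(d₁(x) + d₂(x))^2 / (2σ^2))
# wc: class-balancing weights; d1, d2: distances to the nearest and
# second-nearest cell border, computed elsewhere (e.g. a distance transform).
function unet_weight_map(wc, d1, d2; w0 = 10.0, σ = 5.0)
    return @. wc + w0 * exp(-(d1 + d2)^2 / (2σ^2))
end
```

The resulting map would then be passed as the per-pixel weight into a weighted loss.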