mlyg / unified-focal-loss

Apache License 2.0

Generalization to 1d classification #9

Closed: mphipps2 closed this issue 2 years ago

mphipps2 commented 2 years ago

Hi,

Just wanted to start by thanking you. The unified and asymmetric unified focal losses have worked really well for me on imbalanced 2D segmentation.

That said, is there any reason not to generalize this to imbalanced 1D problems? Currently the code throws this exception if you try to use it on 1D data: 'Metric: Shape of tensor is neither 2D or 3D.'

Edit: I guess it's just the region-based component of the loss function that makes it ill-suited to 1D?

Thanks in advance, Mike

mlyg commented 2 years ago

Hi Mike,

Thank you for taking an interest in this project, and I am glad to hear you have achieved good results!

I have not tested these loss functions on 1D data, but I believe the easiest way around the exception would be to replace the 'identify_axis' function with the following:

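```python
def identify_axis(shape):
    # Returns the spatial axes to reduce over, assuming channels-last
    # tensors of shape (batch, *spatial_dims, channels)
    # Three dimensional: (batch, depth, height, width, channels)
    if len(shape) == 5:
        return [1, 2, 3]
    # Two dimensional: (batch, height, width, channels)
    elif len(shape) == 4:
        return [1, 2]
    # One dimensional: (batch, length, channels)
    elif len(shape) == 3:
        return [1]
    # Exception - Unknown
    else:
        raise ValueError('Metric: Shape of tensor is neither 1D or 2D or 3D.')
```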
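To illustrate what the patched function means for sequence data, here is a small pure-Python sketch (not the repository's own implementation). It assumes channels-last tensors, as the existing 2D and 3D branches suggest, so a rank-3 (batch, length, channels) tensor gets axis [1] and a region-based term such as Dice is summed over the sequence length. The helper `dice_per_class_1d` and its `smooth` constant are hypothetical, for illustration only. Note that a plain classifier output of shape (batch, classes) has rank 2 and would still hit the ValueError, because there is no spatial axis to sum over.

```python
def identify_axis(shape):
    # Patched version: spatial axes for channels-last tensors
    if len(shape) == 5:
        return [1, 2, 3]
    elif len(shape) == 4:
        return [1, 2]
    elif len(shape) == 3:
        return [1]
    else:
        raise ValueError('Metric: Shape of tensor is neither 1D or 2D or 3D.')


def dice_per_class_1d(y_true, y_pred, smooth=1e-6):
    """Soft Dice per (sample, class) for nested lists of shape
    (batch, length, channels).

    Hypothetical helper: it sums over axis 1 (the sequence length),
    which is the axis identify_axis returns for a rank-3 shape.
    """
    scores = []
    for t_seq, p_seq in zip(y_true, y_pred):
        n_classes = len(t_seq[0])
        row = []
        for c in range(n_classes):
            # True positives, false negatives and false positives,
            # accumulated along the sequence for class c
            tp = sum(t[c] * p[c] for t, p in zip(t_seq, p_seq))
            fn = sum(t[c] * (1 - p[c]) for t, p in zip(t_seq, p_seq))
            fp = sum((1 - t[c]) * p[c] for t, p in zip(t_seq, p_seq))
            row.append((2 * tp + smooth) / (2 * tp + fn + fp + smooth))
        scores.append(row)
    return scores


# One sequence of length 3 with 2 classes (one-hot labels)
y_true = [[[1, 0], [0, 1], [1, 0]]]
print(identify_axis((8, 100, 2)))           # axes for a (batch, length, channels) tensor
print(dice_per_class_1d(y_true, y_true))    # perfect prediction gives Dice of 1 per class
```

The Dice score here has no meaning for a single class label per sample, which is why the region-based side of the loss is better suited to per-position (sequence labelling) tasks than to whole-sample classification.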

Please bear in mind that the loss function framework was designed for segmentation (as you pointed out, because of the region-based losses). However, the distribution-based losses (cross entropy, focal loss, etc.) are used for classification, and I would expect them to work, particularly the focal loss for imbalanced data. You can use these either by calling them directly or by setting the 'weight' parameter to 0 in the unified focal loss.
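For a plain classifier with outputs of shape (batch, classes), a distribution-based loss needs no spatial axis at all. Below is a minimal pure-Python sketch of a categorical focal loss (without the alpha class-weighting term), plus a weighted combination that assumes the unified loss takes the form weight * region + (1 - weight) * distribution. Under that assumption, weight = 0 leaves only the focal term, consistent with the advice above. Both function names are hypothetical and are not the repository's API.

```python
import math


def categorical_focal_loss(y_true, y_pred, gamma=2.0, eps=1e-7):
    """Mean focal loss over a batch.

    y_true: one-hot labels, y_pred: predicted probabilities, both nested
    lists of shape (batch, classes). No spatial axis is reduced, so no
    identify_axis-style lookup is needed. Hypothetical helper.
    """
    total = 0.0
    for t_row, p_row in zip(y_true, y_pred):
        sample = 0.0
        for t, p in zip(t_row, p_row):
            p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
            # (1 - p)^gamma down-weights confidently correct predictions
            sample += -t * (1.0 - p) ** gamma * math.log(p)
        total += sample
    return total / len(y_true)


def combined_loss(region_term, distribution_term, weight):
    """Assumed unified form: weight * region + (1 - weight) * distribution.

    With weight = 0 only the distribution-based term remains.
    Hypothetical helper.
    """
    return weight * region_term + (1.0 - weight) * distribution_term


labels = [[0, 1], [1, 0]]
probs = [[0.2, 0.8], [0.6, 0.4]]
focal = categorical_focal_loss(labels, probs)
print(combined_loss(region_term=0.5, distribution_term=focal, weight=0.0))
```

With gamma = 0 the focal loss reduces to ordinary cross entropy; larger gamma shrinks the contribution of well-classified samples, which is what tends to help on imbalanced data.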

Thanks, Michael