qubvel / segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
MIT License
4.76k stars 1.03k forks

Question about class_weights in loss functions #318

Open EtagiBI opened 4 years ago

EtagiBI commented 4 years ago

Hello,

I'm a bit confused by the class_weights parameter in the loss functions. For instance, the JaccardLoss description says: class_weights: Array (``np.array``) of class weights (``len(weights) = num_classes``). So it's a NumPy array. What dimensions should it have? I assumed it should be (image width x image height x number of classes). But how can I assign specific weights to specific classes? I believe the process should be different for binary segmentation and for multiclass segmentation problems.

JordanMakesMaps commented 4 years ago

It doesn't need to be a NumPy array, and it's definitely not a 2D array. It can just be a Python list in which each index holds the weight for the class that occupies the same channel index in the one-hot-encoded mask.

class_weights = [.5, .1, .95] 

The class indexes simply tell a helper function which channels to pull from both the ground truth and the prediction:

# base/functional.py, line 88
gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
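
To make the channel/weight correspondence concrete, here is a minimal NumPy sketch. The label map is made up, and the indexing at the end is only a stand-in for the idea of selecting channels by index; it is not the library's gather_channels code.

```python
import numpy as np

# Hypothetical 2x2 label map with three classes (0, 1, 2).
labels = np.array([[0, 1],
                   [2, 1]])
num_classes = 3

# One-hot encode: shape (H, W, num_classes); channel k is 1 where labels == k.
one_hot = np.eye(num_classes)[labels]

# One weight per channel, listed in channel order:
# 0.5 -> class 0, 0.1 -> class 1, 0.95 -> class 2.
class_weights = [.5, .1, .95]

# Illustrative stand-in for picking a subset of channels by index.
class_indexes = [0, 2]
selected = one_hot[..., class_indexes]

print(one_hot.shape)   # (2, 2, 3)
print(selected.shape)  # (2, 2, 2)
```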

What the class weight values should be is the real question to ask. People usually use sklearn.utils.class_weight.compute_class_weight, but you could implement your own version by counting the samples per class and dividing the total number of samples by each count. See here for discussion. How do you count samples in a multi-channel image? I'm not sure how others do it, but I count the pixels of each class across the whole dataset, along with the total number of pixels in the dataset, and compute the weights from that.
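
A rough sketch of that pixel-counting approach, assuming the masks are integer label maps with values in [0, num_classes). The function name pixel_class_weights is hypothetical; the formula total / (num_classes * count) is the one compute_class_weight uses for class_weight='balanced'.

```python
import numpy as np

def pixel_class_weights(label_masks, num_classes):
    """Inverse-frequency class weights from per-class pixel counts.

    label_masks: iterable of integer arrays with values in [0, num_classes).
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in label_masks:
        # Count how many pixels of each class this mask contains.
        counts += np.bincount(mask.ravel(), minlength=num_classes)
    total = counts.sum()
    # Guard against division by zero for classes that never appear.
    safe_counts = np.maximum(counts, 1)
    return total / (num_classes * safe_counts)

# Two tiny made-up masks: class 0 covers 5 pixels, class 1 covers 3.
masks = [np.array([[0, 0], [0, 1]]), np.array([[0, 0], [1, 1]])]
weights = pixel_class_weights(masks, num_classes=2)
print(weights)  # the rarer class 1 gets the larger weight
```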

In this case, it's not different for binary and multi-class classification because if you think about it, binary classification in semantic segmentation is still two classes. If you're looking at pictures of cats, the two classes are 'cat', and 'not cat'.

...

# score calculation
intersection = backend.sum(gt * pr, axis=axes)
union = backend.sum(gt + pr, axis=axes) - intersection

score = (intersection + smooth) / (union + smooth)
score = average(score, per_image, class_weights, **kwargs) 

return score

...

def average(x, per_image=False, class_weights=None, **kwargs):
    backend = kwargs['backend']
    if per_image:
        x = backend.mean(x, axis=0)
    if class_weights is not None:
        x = x * class_weights
    return backend.mean(x)
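
For a feel of what those two snippets compute, here is a NumPy restatement under a few assumptions: channels-last tensors, per_image=False, and reduction over the batch, height and width axes. weighted_iou is a hypothetical name, and this is a sketch of the quoted logic rather than the library code itself.

```python
import numpy as np

def weighted_iou(gt, pr, class_weights=None, smooth=1e-5):
    """Per-channel IoU followed by the weighting step quoted above.

    gt, pr: arrays of shape (batch, height, width, channels).
    """
    axes = (0, 1, 2)  # assumed reduction axes, leaving one score per channel
    intersection = np.sum(gt * pr, axis=axes)
    union = np.sum(gt + pr, axis=axes) - intersection
    score = (intersection + smooth) / (union + smooth)
    if class_weights is not None:
        score = score * np.asarray(class_weights)
    # Plain mean over channels: the weights scale each score,
    # but the result is not divided by the sum of the weights.
    return score.mean()

# Made-up 2x2 example with two channels (class 0 and its complement).
gt0 = np.array([[1., 1.], [0., 0.]])
pr0 = np.array([[1., 0.], [0., 0.]])
gt = np.stack([gt0, 1 - gt0], axis=-1)[None]  # shape (1, 2, 2, 2)
pr = np.stack([pr0, 1 - pr0], axis=-1)[None]

unweighted = weighted_iou(gt, pr)                          # mean of 1/2 and 2/3
weighted = weighted_iou(gt, pr, class_weights=[1.0, 0.5])  # mean of 1/2 and 1/3
```

Because the weights are applied before a plain mean, the weighted value is not a normalized weighted average; that follows from the average function as quoted.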

EtagiBI commented 4 years ago

Thanks for your reply!

I was confused by a PyCharm warning. It claims that list is an unexpected data type for the class_weights parameter of the loss functions.

As for the difference between binary and multiclass segmentation, I thought that one-hot encoding isn't widely used for binary problems, since both classes can be determined from one mask with a single channel.

JordanMakesMaps commented 4 years ago

If you don't one-hot-encode the binary segmentation map, how do you calculate the loss over the background class category? Would you mind providing a reference from somewhere that discusses this?

EtagiBI commented 4 years ago

The determination of what the class weight values should be, is the question to ask. Usually people will use sklearn.utils.class_weight.compute_class_weight, but you could implement your own version by calculating the numbers of samples per class and dividing the total number of samples over each one.

Well, perhaps I reinvented the wheel, but here's my approach to determining class weights:

import os
import statistics

import cv2

# IMAGE_WIDTH must be defined elsewhere: the side length in pixels of the (square) masks.

def calc_weights(masks_folder):
    """Calculate class weights according to the class distribution in a dataset."""
    class_1_numbers = []
    for file_name in os.listdir(masks_folder):
        mask = cv2.imread(os.path.join(masks_folder, file_name), cv2.IMREAD_GRAYSCALE)
        class_1_numbers.append(cv2.countNonZero(mask))  # object (non-black) pixels in this mask

    # Median object-pixel count per mask; the rest of the image counts as background
    class_1_total = int(statistics.median(class_1_numbers))
    class_0_total = int(IMAGE_WIDTH**2 - class_1_total)
    class_1_weight = 1.  # Maximum value to minority class
    class_0_weight = class_1_total / class_0_total  # Proportional value to majority class for class balance
    return [class_0_weight, class_1_weight]

class_weights = calc_weights(path_to_masks_folder)

I wonder if it's better to calculate class weights over individual batches instead of a whole dataset?

If you don't one-hot-encode the binary segmentation map, how do you calculate the loss over the background class category?

Actually, it's just my naive understanding of things. We have binary_crossentropy for binary segmentation tasks and categorical_crossentropy for multiclass segmentation tasks. Although it's technically possible to use categorical_crossentropy for all segmentation tasks, it's recommended to use binary_crossentropy for binary ones. I assume that's because there's no need for multiple channels with binary masks: all the 1s and 0s are already there. If I'm wrong, it would be nice to understand the reason both binary_crossentropy and categorical_crossentropy exist.
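
One way to see the relationship is a small NumPy check. It assumes the two-channel prediction is exactly [1 - p, p], and the two helper functions are hand-written for illustration, not the Keras implementations. Under that assumption, binary cross-entropy on a single channel and categorical cross-entropy on the two-channel one-hot version give the same value.

```python
import numpy as np

def binary_crossentropy(y_true, p):
    # y_true and p have shape (N,); p is the predicted probability of class 1.
    return -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)).mean()

def categorical_crossentropy(y_onehot, probs):
    # y_onehot and probs have shape (N, 2); each row of probs sums to 1.
    return -(y_onehot * np.log(probs)).sum(axis=-1).mean()

y = np.array([1.0, 0.0, 1.0, 0.0])  # made-up binary labels
p = np.array([0.9, 0.2, 0.6, 0.4])  # made-up foreground probabilities

bce = binary_crossentropy(y, p)
cce = categorical_crossentropy(np.stack([1 - y, y], axis=-1),
                               np.stack([1 - p, p], axis=-1))
print(np.isclose(bce, cce))
```

This equivalence is specific to cross-entropy. An overlap score such as Jaccard computed on a single foreground channel only measures that channel, so it is not automatically the same as averaging over object and background channels.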

EtagiBI commented 4 years ago

@JordanMakesMaps what's the purpose of class_indexes? For instance, I have binary masks that contain objects (white) and background (black). Here's my Dataset class:

class Dataset:
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    """

    CLASSES = ['object', 'background'] # <----------------------- my classes (class 1 and class 0)

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):

        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)

        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # add background if mask is not binary
        if mask.shape[-1] != 1:
            background = 1 - mask.sum(axis=-1, keepdims=True)
            mask = np.concatenate((mask, background), axis=-1)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)
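
How many channels the mask ends up with depends on the classes argument, which matters when lining up class_weights. The sketch below repeats the channel-building steps of __getitem__ in a standalone helper (encode_mask is a hypothetical name) and runs it on a made-up label mask whose pixel values are 0 and 1.

```python
import numpy as np

def encode_mask(mask, class_values):
    """Same channel-building steps as Dataset.__getitem__, on a raw label mask."""
    masks = [(mask == v) for v in class_values]
    encoded = np.stack(masks, axis=-1).astype('float')
    # A background channel is appended only when more than one class is requested.
    if encoded.shape[-1] != 1:
        background = 1 - encoded.sum(axis=-1, keepdims=True)
        encoded = np.concatenate((encoded, background), axis=-1)
    return encoded

mask = np.array([[0, 1], [1, 1]])  # hypothetical pixel values 0 and 1

one_class = encode_mask(mask, class_values=[0])
two_classes = encode_mask(mask, class_values=[0, 1])
print(one_class.shape)    # (2, 2, 1)
print(two_classes.shape)  # (2, 2, 3), the last channel is the appended background
```

Note that class_values holds positions in CLASSES (0, 1, ...), and these are compared against raw pixel values. A mask stored as 0/255 contains no pixel equal to 1, so it may need remapping before this comparison does what is intended.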

Since I'd like to detect objects and don't care about background, I put 'object' in my global variable CLASSES and create datasets:

CLASSES = ['object']

train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
)
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocess_input),
)

Then I create weights to balance my classes:

class_weights = [class_1_weight, class_0_weight] # <----------------------- class_1_weight should be the first, because 'object' class is the first in CLASSES, whereas 'background' is the second
jaccard_loss = sm.losses.JaccardLoss(class_weights=class_weights)
metrics = [sm.metrics.IOUScore(threshold=0.5, class_weights=class_weights)]
optim = keras.optimizers.Adam(LR)
model.compile(optim, jaccard_loss, metrics)

If everything in my example is correct, then it's possible to control which class gets which weight through the order of the weights in the list: class_1_weight goes to object (both are first in their lists), and class_0_weight goes to background (both are second in their lists).

So, what's the application area of class_indexes?

JordanMakesMaps commented 4 years ago

See #340; someone else asked a very similar question. The class_indexes parameter tells the gather_channels() function which channels of the one-hot-encoded mask to pull when calculating the loss and/or metrics.

  1. If you provide nothing for either class_indexes or class_weights, the metric/loss function is left unchanged (i.e. all channels are used, unweighted).
  2. If you provide a list of numbers as class_indexes, only those channels are pulled for calculating the metric/loss; if you also provide the same number of weights in an array as class_weights, those weights are applied (if there is a length mismatch, you get an error).

So its purpose is this: if you're not using all of the channels in the one-hot-encoded mask (for whatever reason), you can still apply weights to the channels that you are using. If you're using all of the channels, you can ignore it and just use the class_weights parameter. At least, that's what I get from the code; I think it's just an additional feature that some find useful depending on their dataset. You could ignore a class by setting its weight to zero, or you could use class_indexes to remove it from the calculation. Different strokes for different folks, I guess.
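
The two ways of ignoring a class are close but not numerically identical, if averaging works as in the average function quoted earlier in the thread (multiply by the weights, then take a plain mean). A small sketch with made-up per-channel scores:

```python
import numpy as np

# Hypothetical per-channel IoU scores for three classes.
scores = np.array([0.9, 0.6, 0.3])

# Option 1: keep all three channels and set the weight of class 2 to zero.
# The mean still divides by 3, so the ignored class pulls the value down.
zero_weighted = (scores * np.array([1.0, 1.0, 0.0])).mean()

# Option 2: keep only channels 0 and 1 (the effect class_indexes=[0, 1] asks for).
# The mean divides by 2.
class_indexes = [0, 1]
selected = scores[class_indexes].mean()

print(zero_weighted)  # approximately 0.5
print(selected)       # approximately 0.75
```

Both options remove class 2 from the gradient signal; they differ only in scale, which can matter when comparing reported metric values between runs.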