shivahanifi / SCDD-image-segmentation-keras

Implementation of SegNet, FCN, UNet, PSPNet and other models in Keras.
https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html
MIT License

use different class weights for training #2

Open · shivahanifi opened 3 days ago

shivahanifi commented 3 days ago

Specific class weights are defined for the five classes in the dataset: background, ribbons, cracks, gridline defects, and inactive areas.

A two-level screening design was selected for the experiment because it is well suited to exploring many factors with a minimum of experimental runs: each factor is set to a low and a high value to examine the main effects and interactions. The low levels were set to 50% of the custom class weights, and the high levels to 50% of the inverse class weights.

Screenshot from 2024-10-22 11-13-36
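
As a rough illustration, the two levels for each class-weight factor could be derived along these lines. The pixel counts and hand-tuned custom weights below are hypothetical placeholders, and the inverse weights are simply taken here as total pixels divided by per-class pixels:

# Hypothetical per-class pixel counts and hand-tuned custom weights (placeholder values)
class_pixel_counts = {"background": 9_000_000, "ribbons": 400_000, "cracks": 150_000,
                      "gridline defects": 120_000, "inactive areas": 330_000}
custom_class_weights = {"background": 0.5, "ribbons": 2.0, "cracks": 5.0,
                        "gridline defects": 6.0, "inactive areas": 3.0}

# Inverse class weights: total pixel count divided by the per-class pixel count
total_pixels = sum(class_pixel_counts.values())
inverse_class_weights = {c: total_pixels / n for c, n in class_pixel_counts.items()}

# Two-level screening design: low = 50% of the custom weight, high = 50% of the inverse weight
low_levels = {c: 0.5 * w for c, w in custom_class_weights.items()}
high_levels = {c: 0.5 * w for c, w in inverse_class_weights.items()}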

shivahanifi commented 1 day ago

To handle class imbalance, you can pass the class weights directly to the loss function:

  1. Define the class weights manually, or compute them with scikit-learn's compute_class_weight, as sketched below.
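
A minimal sketch of the compute_class_weight route; train_annotation_array is a hypothetical array holding the integer class label of every pixel in the training masks:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Flatten all training masks into a 1-D array of per-pixel class labels
labels = train_annotation_array.flatten()
classes = np.unique(labels)

# 'balanced' weights each class inversely proportional to its frequency
weights = compute_class_weight(class_weight='balanced', classes=classes, y=labels)
class_weight_dict = dict(zip(classes, weights))
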
  2. Customize the loss function, e.g.
import keras.backend as K

def weighted_categorical_crossentropy(weights):
    def loss(y_true, y_pred):
        num_classes = len(weights)
        # One-hot encode the integer labels and reshape predictions to (n_pixels, n_classes)
        y_true = K.one_hot(K.cast(K.flatten(y_true), 'int32'), num_classes)
        y_pred = K.reshape(y_pred, (-1, num_classes))
        # Per-pixel cross-entropy, scaled by the weight of each pixel's true class
        pixel_loss = K.categorical_crossentropy(y_true, y_pred)
        pixel_weights = K.sum(y_true * K.constant(weights), axis=-1)
        return K.mean(pixel_loss * pixel_weights)
    return loss

# Define the loss function with the computed class weights
custom_loss = weighted_categorical_crossentropy(list(class_weight_dict.values()))
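
Note that list(class_weight_dict.values()) assumes the dictionary keys are ordered by class index (which they are when built from np.unique as above), so the weight list lines up with the model's output channels. As a quick sanity check of the loss on dummy tensors, assuming the five-class setup:

import numpy as np

# Three dummy pixels with integer labels and uniform 5-class predictions (hypothetical shapes)
y_true_dummy = K.constant(np.array([0, 2, 4]))
y_pred_dummy = K.constant(np.full((3, 5), 0.2))
print(K.eval(custom_loss(y_true_dummy, y_pred_dummy)))  # scalar weighted cross-entropy
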
  3. Compile the model with the custom loss

# Re-compile the model with the custom loss function
model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])

Now you can proceed with training

model.train(
    train_images=train_image_path,
    train_annotations=train_annotations_path,
    checkpoints_path=checkpoint_path,
    epochs=wandb.config.epochs,
    batch_size=2,
    steps_per_epoch=len(os.listdir(train_image_path)) // 2,
    callbacks=[WandbCallback(), checkpoint_callback],
)
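
The wandb run, WandbCallback, and checkpoint_callback referenced in the call above are assumed to be set up elsewhere; a minimal sketch of what that setup might look like (project name and epoch count are placeholders):

import os
import wandb
from wandb.keras import WandbCallback
from keras.callbacks import ModelCheckpoint

# Hypothetical run configuration; epochs is read back via wandb.config.epochs in train()
wandb.init(project="SCDD-image-segmentation", config={"epochs": 20})

# Save weights after every epoch; checkpoint_path is assumed to be defined already
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_path + ".{epoch:02d}",
    save_weights_only=True,
)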

4. Log to wandb

# Log the class weights to WandB
wandb.config.update({"class_weights": class_weight_dict})