xia-xx-cv / EDDFS_dataset

a retinal fundus dataset for eye disease diagnosis and fundus synthesis
MIT License
3 stars 0 forks source link

The model underperforms on EDDFS data and other datasets #2

Open mystvearn opened 2 months ago

mystvearn commented 2 months ago

Thank you for releasing the data and code for reference. I learned from your code and wrote my own code to run inference on a single image. I used your trained model for the multi-label, multi-disease classification task. The results don't seem to be correct: the model fails to predict most images extracted from other datasets. Could you please advise on what could be wrong with my code?

import os
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import cv2
from models import coattnet_v2_withWeighted_tiny  # Ensure this model is in your models folder
from models import resnet18
import random

# Custom configuration for multi-label, multi-disease classification
class EDDFS_delN_ml_conf:
    """Settings for the EDDFS multi-label, multi-disease task (normal samples excluded).

    Attributes are plain instance fields read by the inference script:
    the task tag, the ordered class names (index i matches model output i),
    their count, the square input size, and per-channel normalisation stats.
    """

    def __init__(self):
        self.task = "multi_labels"
        # Order matters: it must match the order of the model's output logits.
        self.classes_names = [
            'DR', 'AMD', 'glaucoma', 'myo', 'rvo', 'LS', 'hyper', 'others',
        ]
        self.classes_num = len(self.classes_names)
        self.image_size = 448
        # The standard ImageNet per-channel mean/std values (RGB order).
        self.mean_brightness = [0.485, 0.456, 0.406]
        self.std_brightness = [0.229, 0.224, 0.225]

    def print_info(self):
        """Print a one-line description of this configuration."""
        banner = "EDDFS multi-label multi-disease without normal samples"
        print(banner)

def clahe_preprocess(image_path, denoise=False, contrastenhancement=True, brightnessbalance=None, cliplimit=2, gridsize=8, mask_path='./mask.png'):
    """Load a fundus image and apply optional brightness balancing, CLAHE and denoising.

    Args:
        image_path: Path of the image, read in colour with ``cv2.imread``.
        denoise: If True, apply non-local-means denoising followed by a
            bilateral filter.
        contrastenhancement: If True, apply CLAHE to the L channel in LAB space.
        brightnessbalance: Target brightness. When truthy, pixel values are
            scaled by ``brightnessbalance / brightness`` and clipped to 255.
            ``None`` (the default) or 0 disables this step.
        cliplimit: CLAHE clip limit.
        gridsize: CLAHE tile grid size (``gridsize x gridsize`` tiles).
        mask_path: Grayscale mask image. It is read only when brightness
            balancing is enabled.

    Returns:
        The processed image as an RGB ``uint8`` array.

    Raises:
        FileNotFoundError: If the image, or the mask when it is needed,
            cannot be read.
        ValueError: If brightness balancing is requested but the mask yields
            no pixels to normalise by.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image '{image_path}'")

    # Brightness balance
    if brightnessbalance:
        # The mask is only used here. Previously it was loaded unconditionally,
        # so a missing mask file crashed even with balancing disabled.
        mask_img = cv2.imread(mask_path, 0)
        if mask_img is None:
            raise FileNotFoundError(f"Could not read mask '{mask_path}'")
        # For a binary 0/255 mask this is the number of zero-valued pixels.
        # NOTE(review): if the mask is white over the fundus, this counts the
        # background rather than the retina; confirm the intended convention.
        z = mask_img.shape[0] * mask_img.shape[1] - mask_img.sum() / 255.
        if z <= 0:
            raise ValueError(f"Mask '{mask_path}' has no zero-valued pixels to normalise by")
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        brightness = gray.sum() / z
        # An all-black image has zero brightness and cannot be rescaled.
        # Leave it unchanged instead of producing inf/NaN pixel values.
        if brightness > 0:
            bgr = np.uint8(np.minimum(bgr * brightnessbalance / brightness, 255))

    # Contrast enhancement with CLAHE on the lightness channel only, so hue is preserved.
    if contrastenhancement:
        lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
        lab_planes = list(cv2.split(lab))
        clahe = cv2.createCLAHE(clipLimit=cliplimit, tileGridSize=(gridsize, gridsize))
        lab_planes[0] = clahe.apply(lab_planes[0])
        lab = cv2.merge(tuple(lab_planes))
        bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    # Denoising
    if denoise:
        bgr = cv2.fastNlMeansDenoisingColored(bgr, None, 10, 10, 1, 3)
        bgr = cv2.bilateralFilter(bgr, 5, 1, 1)

    # OpenCV works in BGR; downstream PIL/torchvision code expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def load_image_with_custom_preprocess(image_path, image_size, mask_path, mean=None, std=None):
    """Read, enhance, resize and normalise one image into a model-ready batch.

    Args:
        image_path: Path of the image file.
        image_size: Target size. An ``int`` produces a square
            ``image_size x image_size`` image; a ``(height, width)`` pair is
            used as-is.
        mask_path: Forwarded to ``clahe_preprocess``.
        mean: Per-channel normalisation mean. Defaults to
            ``[0.485, 0.456, 0.406]``, the value previously hard-coded.
        std: Per-channel normalisation std. Defaults to
            ``[0.229, 0.224, 0.225]``, the value previously hard-coded.

    Returns:
        A float tensor of shape ``(1, 3, H, W)``.
    """
    # Apply CLAHE and other preprocessing (returns an RGB uint8 array).
    processed_image = clahe_preprocess(image_path, mask_path=mask_path)
    pil_image = Image.fromarray(processed_image)

    if isinstance(image_size, int):
        # transforms.Resize(int) only rescales the shorter edge and keeps the
        # aspect ratio, so non-square inputs would not come out at the
        # configured square size. Request the square explicitly.
        image_size = (image_size, image_size)

    mean = [0.485, 0.456, 0.406] if mean is None else mean
    std = [0.229, 0.224, 0.225] if std is None else std

    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Add the leading batch dimension expected by the model.
    return preprocess(pil_image).unsqueeze(0)

# Evaluation function
def evaluate_single_image(model, image_tensor, device, opt_dataset, threshold=0.5):
    """Run multi-label inference on one preprocessed image and return the predicted class names.

    Args:
        model: Network returning one logit per class. It is switched to eval
            mode here.
        image_tensor: Batched input. Only the first sample's outputs are
            printed and thresholded.
        device: Device the input is moved to. The model must already be there.
        opt_dataset: Config object whose ``classes_names`` lists class names in
            model-output order.
        threshold: A class is predicted when its sigmoid probability is
            strictly greater than this value. The default 0.5 reproduces the
            former ``torch.round`` rule, which rounds exactly 0.5 down to 0
            (round-half-to-even).

    Returns:
        List of predicted class names, possibly empty.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        # Independent sigmoid per class: in multi-label mode labels are not mutually exclusive.
        probabilities = torch.sigmoid(logits).cpu().numpy()

    sample_probs = probabilities[0]

    # Print the probabilities for each label
    for idx, prob in enumerate(sample_probs):
        print(f"Probability of {opt_dataset.classes_names[idx]}: {prob:.4f}")

    return [
        opt_dataset.classes_names[i]
        for i, prob in enumerate(sample_probs)
        if prob > threshold
    ]

if __name__ == "__main__":
    import importlib

    # Load model configuration
    opt_dataset = EDDFS_delN_ml_conf()
    opt_dataset.print_info()

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG the script touches so repeated runs are reproducible.
    seed = 2022
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Resolve the architecture by name from the project's `models` package.
    # A plain getattr replaces the former exec()-based import, and no throwaway
    # resnet18() placeholder is built anymore.
    net_name = 'coattnet_v2_withWeighted_tiny'
    create_model = getattr(importlib.import_module("models"), net_name)

    # pretrained=False: the strict load_state_dict below overwrites every
    # parameter, so fetching backbone weights here would be wasted work.
    model = create_model(num_classes=opt_dataset.classes_num, pretrained=False)

    # NOTE(review): the checkpoint name mentions "parallelnet_v2_withWeighted_tiny"
    # while net_name is "coattnet_v2_withWeighted_tiny". A strict load only
    # succeeds if both share parameter names/shapes; confirm this pairing is intended.
    model_path = "weights/NC_delN_ml_7_448-parallelnet_v2_withWeighted_tiny_e51_bFalse_bs32-l9e-05_0.2-preFalse-lossbce.pth.tar"
    if not os.path.isfile(model_path):
        raise ValueError(f"No checkpoint found at '{model_path}'")

    print(f"=> Loading checkpoint '{model_path}'")
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    # strict=True: fail loudly on any missing or unexpected key.
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    print('Model loaded successfully')

    # Move to the target device and switch to evaluation mode.
    model.to(device)
    model.eval()

    # Load the image with custom preprocessing
    image_path = "test_images/0021_crop.jpg"  # Path to your test image
    mask_path = "./mask.png"  # Only read when brightness balancing is enabled
    image_tensor = load_image_with_custom_preprocess(image_path, opt_dataset.image_size, mask_path)

    # Run inference using the evaluation function
    predicted_labels = evaluate_single_image(model, image_tensor, device, opt_dataset)

    # Output the prediction results
    print("Predicted diseases:")
    print(", ".join(predicted_labels) if predicted_labels else "No diseases detected")

Below are the images that I tested the model on:

csr2_crop csr3_crop dr1 dr1_crop dr2 dr3 dr3_crop dr4 RET004OS_crop

3664(SLFY29) crvo1_crop csr1_crop

0001(TJZX9) 0010_crop 0020_crop 0021_crop

xia-xx-cv commented 1 month ago

Thank you for releasing the data and code for reference. I learned from your code and wrote my own code to run inference on a single image. I used your trained model for the multi-label, multi-disease classification task. The results don't seem to be correct: the model fails to predict most images extracted from other datasets. Could you please advise on what could be wrong with my code?

@mystvearn

Here are the results obtained with preprocessing but without brightness balancing (brightness balancing may be sensitive to the dataset's mean brightness; alternatives include '2', '4', and '6'), using a decision threshold of 0.4, run on a Mac M3 Pro.

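Multi-label results:

| image | DR | AMD | glaucoma | myo | rvo | LS | hyper | others |
|-------|----|-----|----------|-----|-----|----|-------|--------|
| test0 | -  | AMD | -        | -   | -   | -  | -     | -      |
| test1 | -  | AMD | glaucoma | -   | -   | -  | -     | -      |
| test2 | -  | AMD | -        | -   | -   | -  | -     | -      |
| test3 | -  | -   | glaucoma | -   | -   | -  | -     | -      |
| test4 | -  | -   | glaucoma | -   | -   | -  | -     | -      |
| test5 | DR | -   | -        | -   | -   | -  | -     | -      |
| test6 | -  | -   | -        | -   | -   | LS | -     | others |
| test7 | -  | -   | -        | -   | rvo | -  | -     | -      |
| test8 | -  | AMD | -        | -   | -   | -  | -     | -      |
| test9 | DR | -   | -        | -   | -   | -  | hyper | others |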

Multi-class results: test0: myo; test1: glaucoma; test2: others; test3: others; test4: others; test5: DR; test6: LS; test7: rvo; test8: others; test9: rvo.

I believe the model correctly classified some of the AMD, LS, and DR cases. I attribute several mistakes, such as those involving myopia, to the fact that the model was trained on the EDDFS dataset, which has an imbalanced sample distribution. In addition, the model misclassified some glaucoma cases, suggesting that glaucoma diagnosis may require more fine-grained annotations.

I have added test_single_img.py to the repo and updated the training, testing, and dataset code to improve results.