ieee8023 / covid-chestxray-dataset

We are building an open database of COVID-19 cases with chest X-ray or CT images.

COVID-19 classification DCNN training code with "explainability" functionality #37

Closed mansilla closed 3 years ago

mansilla commented 4 years ago

In this example, we use ONLY the X-ray samples in the dataset labeled as COVID-19. We went with X-rays instead of CTs since there are more of them, but I agree CTs are better for detection, as mentioned in #5.

The neural network source code is based on a post by Adrian Rosebrock on PyImageSearch.

Here, the dataset was divided into two labels: sick and healthy. The healthy training samples were extracted from this Kaggle contest.

Then, for training, the images are split into two folders, /dataset/sicks and /dataset/healthy, located in the root folder, with each class having the same number of images (around 90).

It's a preliminary approach that may improve substantially once the dataset grows enough.

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries

plt.rcParams["figure.figsize"] = (20,10)

## global params
INIT_LR = 1e-4  # learning rate
EPOCHS = 21  # training epochs
BS = 8  # batch size

## load and prepare data
imagePaths = list(paths.list_images("dataset"))
data = []
labels = []
# loop over the image paths
for imagePath in imagePaths:
    # extract the class label from the filename
    label = imagePath.split(os.path.sep)[-2]
    # load the image, swap color channels, and resize it to be a fixed
    # 224x224 pixels while ignoring aspect ratio
    image = cv2.imread(imagePath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # update the data and labels lists, respectively
    data.append(image)
    labels.append(label)
# convert the data and labels to NumPy arrays while scaling the pixel
# intensities to the range [0, 1]
data = np.array(data) / 255.0
labels = np.array(labels)

TEST_SET_SIZE = 0.2

lb = LabelBinarizer()
labels = lb.fit_transform(labels)
labels = to_categorical(labels); print(labels)
# partition the data into training and testing splits using 80% of
# the data for training and the remaining 20% for testing
(trainX, testX, trainY, testY) = train_test_split(data, labels,
    test_size=TEST_SET_SIZE, stratify=labels, random_state=42)
# initialize the training data augmentation object
trainAug = ImageDataGenerator(
    rotation_range=15,
    fill_mode="nearest")

## build network
baseModel = VGG16(weights="imagenet", include_top=False,
    input_tensor=Input(shape=(224, 224, 3)))
# construct the head of the model that will be placed on top of the
# the base model
headModel = baseModel.output
headModel = AveragePooling2D(pool_size=(4, 4))(headModel)
headModel = Flatten(name="flatten")(headModel)
headModel = Dense(64, activation="relu")(headModel)
headModel = Dropout(0.5)(headModel)
headModel = Dense(2, activation="softmax")(headModel)
# place the head FC model on top of the base model (this will become
# the actual model we will train)
model = Model(inputs=baseModel.input, outputs=headModel)
# loop over all layers in the base model and freeze them so they will
# *not* be updated during the first training process
for layer in baseModel.layers:
    layer.trainable = False

print("[INFO] compiling model...")
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="binary_crossentropy", optimizer=opt,
    metrics=["accuracy"])

## train
print("[INFO] training head...")
H = model.fit_generator(
    trainAug.flow(trainX, trainY, batch_size=BS),
    steps_per_epoch=len(trainX) // BS,
    validation_data=(testX, testY),
    validation_steps=len(testX) // BS,
    epochs=EPOCHS)

print("[INFO] saving COVID-19 detector model...")
model.save("covid19.model", save_format="h5")

## eval
print("[INFO] evaluating network...")
predIdxs = model.predict(testX, batch_size=BS)
predIdxs = np.argmax(predIdxs, axis=1) # argmax for the predicted probability
print(classification_report(testY.argmax(axis=1), predIdxs,
    target_names=lb.classes_))

cm = confusion_matrix(testY.argmax(axis=1), predIdxs)
total = sum(sum(cm))
acc = (cm[0, 0] + cm[1, 1]) / total
sensitivity = cm[0, 0] / (cm[0, 0] + cm[0, 1])
specificity = cm[1, 1] / (cm[1, 0] + cm[1, 1])
# show the confusion matrix, accuracy, sensitivity, and specificity
print(cm)
print("acc: {:.4f}".format(acc))
print("sensitivity: {:.4f}".format(sensitivity))
print("specificity: {:.4f}".format(specificity))

## plot training history
N = EPOCHS
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["accuracy"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy (COVID-19 detection)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig("training_plot.png")

## explain
explainer = lime_image.LimeImageExplainer()
for ind in range(10):
    # explain one test image; num_samples is kept low here for speed
    explanation = explainer.explain_instance(testX[ind], model.predict,
                                             hide_color=0, num_samples=42)
    print("> label:", testY[ind].argmax(), "- predicted:", predIdxs[ind])
    # highlight the superpixels that contributed to the top predicted class
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=False, num_features=1, hide_rest=True)
    plt.imshow(mark_boundaries(temp, mask))
    plt.show()

In the end, you will have some visualizations of how the network is "detecting" (if the evaluation metrics make sense) COVID-19 suspicious regions in the X-rays.

[image: sample_detection]

Comment 1: In my experience, this Lime explanation method can be handy when classifying images and trying to understand what the network is actually "looking at" to make the decision.

Comment 2: I was wondering why the classification accuracy was so high here (and in the original PyImageSearch post). I think it is because the Kaggle dataset is so well standardized that the NN is learning to predict whether an X-ray comes from Kaggle or from this dataset, instead of classifying healthy/sick. Nevertheless, I feel the source code is still relevant, and with more X-ray data and better preprocessing we will be able to fix this issue and improve the algorithm.
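
A minimal sanity-check sketch of this suspicion (not part of the training script above, and only assuming the data and labels arrays it builds): if a plain logistic regression on heavily downscaled images already separates the two folders almost perfectly, the signal is probably superficial (acquisition or source differences) rather than pathology.

# Sanity-check sketch (not part of the original script): a raw-pixel baseline on
# downscaled images. Near-perfect accuracy here would suggest the classes are
# separable by superficial image statistics rather than pathology.
# Assumes `data` and `labels` as built in the script above.
import cv2
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

small = np.array([cv2.resize(img, (32, 32)) for img in data])  # throw away fine detail
flat = small.reshape(len(small), -1)
y = labels.argmax(axis=1)  # back from one-hot to 0/1
scores = cross_val_score(LogisticRegression(max_iter=1000), flat, y, cv=5)
print("raw-pixel baseline accuracy: %.3f +/- %.3f" % (scores.mean(), scores.std()))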

fmobrj commented 4 years ago

Well, I don't know if that is the reason why the classification is so good. I say that because I created a 3rd category, dividing the Kaggle dataset into "normal" and "other pneumonia", and the model can still almost perfectly separate the 3 classes: "normal", "covid" and "other viral or bacterial pneumonia" (97.5% accuracy). So we still have to see whether the high accuracy is caused by the image origin or by some other bias associated with the images.

ajaymaity commented 4 years ago

There's a huge difference in the way the Kaggle images were captured compared to the X-ray images in this repo. For example, the Kaggle images almost always have dark edges on the left and right sides, which the images in this repo don't. I performed a simple comparison: after binarizing the images, I summed the pixels in 10-pixel strips cropped from the left and right edges of images from both dataset sources; you can see the distribution below. The normal dataset from Kaggle has far more black edge pixels than the COVID images in this repo.

Covid dataset from this repo: [image]

Normal dataset from Kaggle: [image]

Kernel Density Distribution: [image]

So, I wouldn't trust a system performing classification on these kinds of classes. The ML algorithm just has to learn to check whether the edges contain black pixels, and that's it, you get high accuracy.
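
A rough sketch of this kind of edge-pixel check (a reconstruction, not the original analysis: the folder names, binarization threshold, and strip width are assumptions):

# Rough reconstruction of the edge-pixel comparison described above. Assumes two
# folders, "covid/" (images from this repo) and "normal/" (Kaggle images); the
# exact threshold and crop width used in the original analysis may differ.
import glob
import cv2
import numpy as np

def edge_black_fraction(path, strip=10, thresh=50):
    """Fraction of near-black pixels in strips at the left/right edges."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (224, 224))
    strips = np.hstack([img[:, :strip], img[:, -strip:]])
    return (strips < thresh).mean()  # binarize, then average

covid = [edge_black_fraction(p) for p in glob.glob("covid/*")]
normal = [edge_black_fraction(p) for p in glob.glob("normal/*")]
print("covid  mean edge-black fraction: %.3f" % np.mean(covid))
print("normal mean edge-black fraction: %.3f" % np.mean(normal))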

mansilla commented 4 years ago

@fmobrj it was a supposition I made based on the visualizations I got with Lime. Also, it's really weird that the performance is so high with such a small number of images.

@ajaymaity yes, that's what I meant. Although I'm not sure whether it is the black border you mention or something else, and there's no way to know that from a kernel density distribution of the images, that's for sure. We make these kinds of suppositions all the time, and most of the time they are wrong.

The only way to know what the NN is doing is to visualize the attention of the network. That's the whole point of this code.

In any case, I feel lung detection is a significant improvement to the preprocessing that needs to be done. If we detected the lung border and blacked out everything outside that area, we would force the network "to look" only inside the lungs. This kernel provides a good example of what I mean. I will add it as soon as I can.

ieee8023 commented 4 years ago

Those kaggle pneumonia images that people keep using are of children. The model can learn age to tell them apart. Maybe the explanation of the image is showing how big the chest is as a predictor.

I would suggest using this data of mostly adults with pneumonia (which is also from a kaggle challenge which could be confusing, let's call this one the RSNA kaggle dataset): https://academictorrents.com/details/95588a735c9ae4d123f3ca408e56570409bcf2a9

fmobrj commented 4 years ago

> Those kaggle pneumonia images that people keep using are of children. The model can learn age to tell them apart. Maybe the explanation of the image is showing how big the chest is as a predictor.
>
> I would suggest using this data of mostly adults with pneumonia (which is also from a kaggle challenge which could be confusing, let's call this one the RSNA kaggle dataset): https://academictorrents.com/details/95588a735c9ae4d123f3ca408e56570409bcf2a9

Indeed. Using the RSNA images, the accuracy drops to ~80% for me. Now I will try some self-supervised learning, using all images from both datasets (RSNA and yours) to pretrain on a pretext task, and see if the performance improves. For classification I am training a 3-class model: covid, normal, and other lung conditions.

Thanks, Joseph.
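
A minimal sketch of one possible pretext task (rotation prediction); the comment above does not say which task is actually used, and the pooled array allX of unlabeled images is an assumption:

# Hypothetical pretext-task sketch (rotation prediction); the actual task used above
# is not specified. Assumes `allX`, an array of unlabeled 224x224x3 images pooled
# from both datasets. The pretrained backbone could then replace the ImageNet VGG16
# in the classification script.
import numpy as np
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

def make_rotation_task(images):
    xs, ys = [], []
    for img in images:
        k = np.random.randint(4)        # rotate by 0, 90, 180 or 270 degrees
        xs.append(np.rot90(img, k))
        ys.append(k)
    return np.array(xs), to_categorical(ys, 4)

backbone = VGG16(weights=None, include_top=False, input_tensor=Input(shape=(224, 224, 3)))
x = Flatten()(backbone.output)
x = Dense(64, activation="relu")(x)
out = Dense(4, activation="softmax")(x)  # predict the rotation class
pretext_model = Model(backbone.input, out)
pretext_model.compile(optimizer="adam", loss="categorical_crossentropy",
                      metrics=["accuracy"])

rotX, rotY = make_rotation_task(allX)
pretext_model.fit(rotX, rotY, batch_size=8, epochs=5)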

bganglia commented 4 years ago

@ieee8023 I wrote some code to implement @mansilla's lung selection idea. I could contribute it to torchxrayvision if you think there could be a place for preprocessing routines.

from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

#Download model from https://github.com/imlab-uiip/lung-segmentation-2d/blob/master/trained_model.hdf5?raw=true
lung_finder = load_model("trained_model.hdf5?raw=true") #edit path as necessary

def lungs_from_chest(
        radiograph, 
        max_value=255, 
        thresh=20, 
        out_size=None,
        mask_out = False):
    """Crop out everything but the lungs in a chest X-ray."""
    #Note final size
    if out_size is None:
        out_size = radiograph.shape[:2]
    #Identify lungs
    radiograph = cv2.resize(radiograph,(256,256))
    pred_radiograph = np.array([radiograph/max_value])
    if len(pred_radiograph.shape) == 3:
        pred_radiograph = np.expand_dims(pred_radiograph,3)
    lung_confidence = lung_finder.predict(pred_radiograph)[0]
    #Find contours
    mask = (lung_confidence * 255 > 1).astype(np.uint8)
    # cv2.findContours returns 2 values in OpenCV 4.x and 3 in 3.x; index from the end
    contours = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    #Find bounds of all contours
    points = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > thresh:
            rect = cv2.minAreaRect(contour)
            box = cv2.boxPoints(rect)
            points.extend(box)
    points = np.array(points).astype(int)  # int, not uint8: coordinates can reach 256
    x, y = zip(*points)
    #Optionally mask out
    if mask_out:
        radiograph = radiograph * mask
    #Select only lungs
    cropped = radiograph[min(x):max(x), min(y):max(y)]
    #Resize to correct size
    return cv2.resize(cropped, out_size)

This code also lives at https://github.com/bganglia/chest_xray_preprocessing/blob/master/lungs_from_chest.py
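
A hypothetical usage sketch (the image path and output size here are assumptions; the segmentation model file must be downloaded as noted in the code comment):

# Example call to the helper above; "dataset/sicks/example.jpg" is a placeholder path.
import cv2
import matplotlib.pyplot as plt

xray = cv2.imread("dataset/sicks/example.jpg", cv2.IMREAD_GRAYSCALE)
lungs = lungs_from_chest(xray, out_size=(224, 224))
plt.subplot(1, 2, 1); plt.imshow(xray, cmap="gray"); plt.title("original")
plt.subplot(1, 2, 2); plt.imshow(lungs, cmap="gray"); plt.title("cropped to lungs")
plt.show()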

pA1nD commented 4 years ago

Has anyone tried Stanford's CheXpert, e.g. for pretraining? https://stanfordmlgroup.github.io/competitions/chexpert/

fmobrj commented 4 years ago

Taking a look at it now.

mansilla commented 4 years ago

@bganglia thank you for this! Let's make it work on the dataset @ieee8023 suggested, the RSNA Kaggle dataset: https://academictorrents.com/details/95588a735c9ae4d123f3ca408e56570409bcf2a9

JiayuanDing100 commented 4 years ago

@mansilla Hi, where do you find the label indicating COVID-19 or not? I couldn't find it in the metadata or the image data.

mansilla commented 4 years ago

@JiayuanDing100 Hey, I used the metadata.

armiro commented 4 years ago

> Those kaggle pneumonia images that people keep using are of children. The model can learn age to tell them apart. Maybe the explanation of the image is showing how big the chest is as a predictor.
>
> I would suggest using this data of mostly adults with pneumonia (which is also from a kaggle challenge which could be confusing, let's call this one the RSNA kaggle dataset): https://academictorrents.com/details/95588a735c9ae4d123f3ca408e56570409bcf2a9

Exactly. I asked radiologists about this high-accuracy problem and they told me the same. Pediatric chests have a different shape from adult/elderly chests, so even basic models can tell them apart with high accuracy.

armiro commented 4 years ago

> Has anyone tried Stanford's CheXpert, e.g. for pretraining? https://stanfordmlgroup.github.io/competitions/chexpert/

I collected them and trained on them. Still getting high accuracies, but not as high as with the previous RSNA images. The features detected (using Grad-CAM) are still not the ones expected (patchy opacities, mostly peripheral). I think the only way forward is to increase the number of COVID pneumonia cases and to collect normal cases from different datasets. Segmenting lobes from the images is also a good idea.