JuliaWolleb / diffusion-anomaly

Anomaly detection with diffusion models
MIT License
115 stars 24 forks source link

Data Preprocess code #19

Open waiterxiaoyy opened 1 year ago

waiterxiaoyy commented 1 year ago

Preprocess code

First of all, I want to express my appreciation for the author's work and contribution. However, since the author didn't provide the preprocessing code for the dataset in the GitHub repository, I've taken the initiative to create and share my preprocessing code.

While the author utilized the BRATS2020 dataset, my preprocessing code is based on the BRATS2021 dataset. In essence, there's not much difference between the two. Despite the differences, the core process remains the same.

The BRATS2021 dataset contains data from a total of 1251 patients; each patient's data consists of four modalities plus the corresponding segmentation labels.

Following the author's paper and code example, I've chosen to extract slices for each modality and the segmentation target, specifically focusing on the slices with indices 80 to 128 (i.e. `range(80, 129)` in the code below, 49 slices per volume). This means discarding the first 80 and the last 26 slices.

The code consists of two main parts:

  1. Organizing each modality's and the segmentation target's slices into individual directories (e.g. 000001).
  2. Shuffling and splitting the directories into a 9:1 ratio for training and validation sets. These sets are stored in the 'training' and 'testing' directories respectively. It is worth noting that each directory in the 'testing' set contains segmentation target label files.

If you wish to strictly follow the code example, you can extract the segmentation labels files separately from 'testing'.

The BRATS2021 dataset contains a larger volume of data than BRATS2020.

Based on my trials, the training set comprises a total of 52126 slices, with 18203 being healthy and 33923 being abnormal. The testing set contains 5792 slices, out of which 1991 are healthy and 3801 are abnormal. Due to randomness, the exact numbers of healthy and abnormal slices may vary with each code run, but the proportional difference should not be substantial.

Below is the preprocessing code:

import os
import numpy as np
import nibabel
import random
import shutil

# Dataset storage location; preprocess() expects one sub-directory per patient here
data_path = './datasets/brats/data'

# Location to save preprocessed data
save_path = './datasets/brats/anomaly_brats'

# makedirs(..., exist_ok=True) creates missing parent directories too and does
# not fail when the directory already exists (no check-then-create race).
os.makedirs(save_path, exist_ok=True)

# Entries of the dataset directory, read once when the module is loaded
dir_list = os.listdir(data_path)

# File-name suffixes loaded for every patient; 'seg' is the segmentation label
modalities = ['t1', 't1ce', 't2', 'flair', 'seg']

def preprocess(slice_start=80, slice_stop=129):
    """Cut every patient volume into per-slice folders under ``save_path``.

    For each patient directory listed in ``dir_list`` the volumes of all
    entries in ``modalities`` are loaded, and for every slice index in
    ``range(slice_start, slice_stop)`` a new, consecutively numbered folder
    (``000001``, ``000002``, ...) is created that holds one 2D NIfTI file per
    modality.

    * Image modalities are min-max normalised per slice to ``[0, 1]``.
    * For ``'seg'`` the label value 4 is remapped to 3, so the labels become
      0, 1, 2, 3.

    Args:
        slice_start: First slice index (along the last axis) to export.
        slice_stop: Slice index at which to stop (exclusive).
    """
    num = 1

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # numbering of the output folders reproducible between runs.
    for index, patient_id in enumerate(sorted(dir_list)):
        patient_path = os.path.join(data_path, patient_id)

        # Skip anything that is not a patient directory (e.g. '.DS_Store'
        # or other stray files), which would otherwise make the load fail.
        if not os.path.isdir(patient_path):
            continue

        print(f"{index} / {len(dir_list)}")

        volumes = {}
        for model in modalities:
            filename = patient_id + "_" + model + ".nii.gz"
            file_path = os.path.join(patient_path, filename)
            volumes[model] = nibabel.load(file_path).get_fdata()

        for i in range(slice_start, slice_stop):
            file_num = str(num).zfill(6)
            save_slice_path = os.path.join(save_path, file_num)
            os.makedirs(save_slice_path, exist_ok=True)

            for model in modalities:
                file_name = patient_id + "_" + model + "_" + str(i).zfill(3) + ".nii.gz"
                save_model_path = os.path.join(save_slice_path, file_name)

                slice_2d = volumes[model][..., i]

                if model == 'seg':
                    # Map label values: 0, 1, 2, 4 to 0, 1, 2, 3.
                    # Work on a copy so the loaded volume is not modified
                    # through the slice view.
                    out = slice_2d.copy()
                    out[out == 4] = 3
                else:
                    # Per-slice min-max normalisation; NaNs are ignored when
                    # determining the range.
                    lowest = np.nanmin(slice_2d)
                    span = np.nanmax(slice_2d) - lowest
                    if span == 0:
                        # Constant slice: avoid dividing by zero.
                        span = 1.0
                    out = (slice_2d - lowest) / span  # (240, 240)

                    if out.max() > 1.0 or out.min() < 0:
                        print(f"--Error: {num} --")

                # Slices are stored with an identity affine.
                nibabel.save(nibabel.Nifti1Image(out, affine=np.eye(4)), save_model_path)
            num += 1

def split_data():
    """Shuffle the per-slice folders in ``save_path`` and split them 9:1.

    The first 90% of the shuffled folders are moved into
    ``save_path/training`` and the remainder into ``save_path/testing``.
    A folder counts as healthy when the maximum of its segmentation file is 0.
    The healthy/abnormal counts of both sets are printed at the end.

    Raises:
        FileNotFoundError: If a slice folder contains no file with ``"seg"``
            in its name.
    """
    # Set paths and directory names
    split_names = ("training", "testing")
    training_path = os.path.join(save_path, "training")
    testing_path = os.path.join(save_path, "testing")

    # Create directories for training and testing sets
    os.makedirs(training_path, exist_ok=True)
    os.makedirs(testing_path, exist_ok=True)

    # Only the slice folders take part in the split. The 'training' and
    # 'testing' directories (and stray files) are removed *before* shuffling,
    # so the position in the list decides the split correctly and the
    # reported totals match what is actually moved.
    sample_dirs = [
        name for name in os.listdir(save_path)
        if name not in split_names
        and os.path.isdir(os.path.join(save_path, name))
    ]

    # Shuffle the list of slice folders
    random.shuffle(sample_dirs)

    # Calculate the number of samples for training and testing sets
    total_samples = len(sample_dirs)
    train_samples = int(0.9 * total_samples)
    test_samples = total_samples - train_samples

    train_health_num = 0
    test_health_num = 0

    for i, dir_name in enumerate(sample_dirs):
        print(f"{i} / {total_samples}")

        source_dir_path = os.path.join(save_path, dir_name)

        # Extract label to determine health status
        seg_files = [
            file_name for file_name in os.listdir(source_dir_path)
            if "seg" in file_name
        ]
        if not seg_files:
            raise FileNotFoundError(
                f"No segmentation file found in {source_dir_path}"
            )

        seg_file = os.path.join(source_dir_path, seg_files[0])
        image = nibabel.load(seg_file).get_fdata()
        is_healthy = image.max() == 0

        if i < train_samples:
            if is_healthy:
                train_health_num += 1
            destination_dir_path = os.path.join(training_path, dir_name)
        else:
            if is_healthy:
                test_health_num += 1
            destination_dir_path = os.path.join(testing_path, dir_name)

            # Extract 'seg' for test_labels separately
            # ... (write your code here)

        # Move directories
        shutil.move(source_dir_path, destination_dir_path)

    print(f"Training set: Healthy {train_health_num}, Abnormal {train_samples - train_health_num}, Total {train_samples}")
    print(f"Testing set: Healthy {test_health_num}, Abnormal {test_samples - test_health_num}, Total {test_samples}")

if __name__ == "__main__":
    # Run the pipeline in order: slice the volumes into per-slice folders
    # first, then distribute those folders over the training/testing sets.
    for step in (preprocess, split_data):
        step()