securefederatedai / openfl

An open framework for Federated Learning.
https://openfl.readthedocs.io/en/latest/index.html
Apache License 2.0

For 3D Image data #106

Closed ghost closed 3 years ago

ghost commented 3 years ago

Hi there,

I am trying to process some 3D medical images (.nii.gz files) with OpenFL, but I am having trouble doing so. My data loader is as follows (taken from the 3D U-Net model):

def get_dataset(self):

    self.num_train = int(self.numFiles * self.train_test_split)
    numValTest = self.numFiles - self.num_train
    ds = tf.data.Dataset.range(self.numFiles).shuffle(
        self.numFiles, self.random_seed)  # Shuffle the dataset
    ds_train = ds.take(self.num_train).shuffle(
        self.num_train, self.shard)  # Reshuffle based on shard
    ds_val_test = ds.skip(self.num_train)
    self.num_val = int(numValTest * self.validate_test_split)
    self.num_test = numValTest - self.num_val  # Test gets the rest of the held-out files
    ds_val = ds_val_test.take(self.num_val)
    ds_test = ds_val_test.skip(self.num_val)

    ds_train = ds_train.map(lambda x: tf.py_function(self.read_nifti_file,
                                                     [x, True], [tf.float32, tf.float32]),
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_val = ds_val.map(lambda x: tf.py_function(self.read_nifti_file,
                                                 [x, False], [tf.float32, tf.float32]),
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_test = ds_test.map(lambda x: tf.py_function(self.read_nifti_file,
                                                   [x, False], [tf.float32, tf.float32]),
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    ds_train = ds_train.repeat()
    ds_train = ds_train.batch(self.batch_size)
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

    batch_size_val = 4
    ds_val = ds_val.batch(batch_size_val)
    ds_val = ds_val.prefetch(tf.data.experimental.AUTOTUNE)

    batch_size_test = 1
    ds_test = ds_test.batch(batch_size_test)
    ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

    return ds_train, ds_val, ds_test

This outputs the PrefetchDataset objects ds_train, ds_val and ds_test. However, according to the data loader file, I believe OpenFL expects data loaders to output X_train, y_train, X_valid and y_valid, and then performs follow-up operations (e.g., get batch) on them. I would find it easier to have an option to use the PrefetchDataset objects directly instead of converting them to X_train, y_train, etc.


So I was wondering whether OpenFL has a way to support data loaders for .nii.gz files?

Thank you so much for your attention!

tonyreina commented 3 years ago

Yes. You should be able to subclass the TensorFlowDataLoader class and override the get_train_loader and get_valid_loader methods.

https://github.com/intel/openfl/blob/6191a61ada83e08b9a5503c64e39ebc22ecbb7ff/openfl/federated/data/loader_tf.py#L44

It'd be something like this (note: I haven't verified this code works):

import tensorflow as tf

from openfl.federated import TensorFlowDataLoader
from .brats_utils import load_from_nifti

class TensorFlowBratsPrefetch(TensorFlowDataLoader):
    """TensorFlow Data Loader for the BraTS dataset."""

    def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, **kwargs):
        """Initialize.

        Args:
            data_path: The file path for the BraTS dataset
            batch_size (int): The batch size to use
            percent_train (float): The percentage of the data to use for training (Default=0.8)
            pre_split_shuffle (bool): True= shuffle the dataset before
            performing the train/validate split (Default=True)
            **kwargs: Additional arguments, passed to super init and load_from_nifti

        Returns:
            Data loader with BraTS data
        """
        super().__init__(batch_size, **kwargs)

        self.batch_size = batch_size
        # NOTE: self.numFiles, self.train_test_split, self.validate_test_split,
        # self.random_seed, self.shard and self.read_nifti_file are assumed to be
        # defined as in the 3D U-Net data loader; set them before this point.
        self.num_train = int(self.numFiles * self.train_test_split)
        numValTest = self.numFiles - self.num_train
        ds = tf.data.Dataset.range(self.numFiles).shuffle(
            self.numFiles, self.random_seed)  # Shuffle the dataset
        ds_train = ds.take(self.num_train).shuffle(
            self.num_train, self.shard)  # Reshuffle based on shard
        ds_val_test = ds.skip(self.num_train)
        self.num_val = int(numValTest * self.validate_test_split)
        self.num_test = numValTest - self.num_val  # Test gets the rest of the held-out files
        ds_val = ds_val_test.take(self.num_val)
        ds_test = ds_val_test.skip(self.num_val)

        # Load each NIfTI file lazily; the boolean flag marks the training split.
        ds_train = ds_train.map(lambda x: tf.py_function(self.read_nifti_file,
                                                         [x, True], [tf.float32, tf.float32]),
                                num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds_val = ds_val.map(lambda x: tf.py_function(self.read_nifti_file,
                                                     [x, False], [tf.float32, tf.float32]),
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds_test = ds_test.map(lambda x: tf.py_function(self.read_nifti_file,
                                                       [x, False], [tf.float32, tf.float32]),
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)

        ds_train = ds_train.repeat()
        ds_train = ds_train.batch(self.batch_size)
        ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

        batch_size_val = 4  # Same value as in the original get_dataset
        ds_val = ds_val.batch(batch_size_val)
        ds_val = ds_val.prefetch(tf.data.experimental.AUTOTUNE)

        batch_size_test = 1  # Same value as in the original get_dataset
        ds_test = ds_test.batch(batch_size_test)
        ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

        self.ds_test = ds_test
        self.ds_train = ds_train
        self.ds_val = ds_val

    def get_feature_shape(self):
        """
        Get the shape of an example feature array.

        Returns:
            tuple: shape of an example feature array
        """
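        # Untested: read one batch and drop the leading batch dimension (needs eager
        # execution). It may be better to compute this once in the init function.
        features, _ = next(iter(self.ds_train))
        return tuple(features.shape[1:])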

    def get_train_loader(self, batch_size=None):
        """
        Get training data loader.

        Returns:
            loader object
        """
        return self.ds_train

    def get_valid_loader(self, batch_size=None):
        """
        Get validation data loader.

        Returns:
            loader object
        """
        return self.ds_val

    def get_train_data_size(self):
        """
        Get total number of training samples.

        Returns:
            int: number of training samples
        """
        return self.num_train

    def get_valid_data_size(self):
        """
        Get total number of validation samples.

        Returns:
            int: number of validation samples
        """
        return self.num_val
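
As a sanity check on the split arithmetic used in the init function, here is a minimal sketch in plain Python with no TensorFlow or OpenFL dependency. The helper name split_counts and the example values (100 files, 0.8, 0.5) are made up for illustration. It also shows why the test count should come from the held-out pool (numValTest - num_val) rather than from num_train: for these numbers, num_train - num_val would give 70 even though only 20 files are held out.

```python
def split_counts(num_files, train_test_split, validate_test_split):
    """Return (num_train, num_val, num_test) for a three-way split.

    Hypothetical helper: it mirrors the arithmetic in the loader sketch
    above and is not part of OpenFL.
    """
    num_train = int(num_files * train_test_split)
    num_val_test = num_files - num_train  # files held out from training
    num_val = int(num_val_test * validate_test_split)
    # The test set is whatever remains of the held-out pool.
    num_test = num_val_test - num_val
    return num_train, num_val, num_test


# Example values are assumptions, not taken from the BraTS dataset.
num_train, num_val, num_test = split_counts(100, 0.8, 0.5)
print(num_train, num_val, num_test)  # 80 10 10
```

The three counts always add up to the number of files, which is an easy invariant to assert before building the tf.data pipeline.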

ghost commented 3 years ago

Thank you very much for your reply!