Closed ghost closed 3 years ago
Yes. You should be able to subclass the TensorFlowDataLoader class and override its get_train_loader and get_valid_loader methods.
It'd be something like this (note: I haven't verified this code works):
from openfl.federated import TensorFlowDataLoader

from .brats_utils import load_from_nifti
class TensorFlowBratsPrefetch(TensorFlowDataLoader):
    """TensorFlow data loader for the BraTS dataset built on ``tf.data``.

    Instead of materialising ``X_train``/``y_train`` arrays, this loader
    exposes batched, prefetched ``tf.data.Dataset`` pipelines for the
    train, validation and test splits.

    The module must have ``import tensorflow as tf`` in scope.

    NOTE(review): the constructor reads ``self.numFiles``,
    ``self.validate_test_split``, ``self.random_seed``, ``self.shard`` and
    calls ``self.read_nifti_file``. None of these are defined in this
    class; they must be provided (by a base class, a subclass, or set
    before the pipelines are built) -- confirm where they come from.
    """

    def __init__(self, data_path, batch_size, percent_train=0.8,
                 pre_split_shuffle=True, **kwargs):
        """Build the train/validation/test ``tf.data`` pipelines.

        Args:
            data_path: The file path for the BraTS dataset.
                NOTE(review): not used by the code in this class;
                presumably consumed by whatever sets ``self.numFiles``
                and by ``self.read_nifti_file`` -- confirm.
            batch_size (int): The batch size used for all three splits.
            percent_train (float): Fraction of the files assigned to the
                training split (Default=0.8).
            pre_split_shuffle (bool): If True, shuffle the file indices
                before performing the train/validate/test split
                (Default=True).
            **kwargs: Additional arguments, passed to the super init.
        """
        super().__init__(batch_size, **kwargs)
        self.batch_size = batch_size

        num_files = self.numFiles
        self.num_train = int(num_files * percent_train)
        num_val_test = num_files - self.num_train
        self.num_val = int(num_val_test * self.validate_test_split)
        # The test split is whatever remains of the non-training files
        # once the validation files are taken out.
        self.num_test = num_val_test - self.num_val

        indices = tf.data.Dataset.range(num_files)
        if pre_split_shuffle:
            # reshuffle_each_iteration=False keeps the permutation fixed
            # across passes; otherwise take()/skip() below would select
            # different files on every epoch and the splits would overlap.
            indices = indices.shuffle(
                num_files,
                seed=self.random_seed,
                reshuffle_each_iteration=False,
            )

        # Only the training indices are reshuffled (seeded by the shard),
        # which happens after the split and therefore cannot leak files.
        ds_train = indices.take(self.num_train).shuffle(
            self.num_train, seed=self.shard)
        ds_val_test = indices.skip(self.num_train)
        ds_val = ds_val_test.take(self.num_val)
        ds_test = ds_val_test.skip(self.num_val)

        ds_train = self._map_read(ds_train, True)
        ds_val = self._map_read(ds_val, False)
        ds_test = self._map_read(ds_test, False)

        # The training pipeline repeats indefinitely; callers bound an
        # epoch themselves (e.g. using get_train_data_size()).
        ds_train = ds_train.repeat()
        self.ds_train = self._batch_and_prefetch(ds_train)
        # The same batch size is used for validation and test, since the
        # constructor takes only one batch_size.
        self.ds_val = self._batch_and_prefetch(ds_val)
        self.ds_test = self._batch_and_prefetch(ds_test)

        # Filled lazily by get_feature_shape().
        self._feature_shape = None

    def _map_read(self, dataset, flag):
        """Map file indices to (float32, float32) tensors via read_nifti_file.

        ``flag`` is forwarded as the second argument of
        ``self.read_nifti_file`` (True only for the training split); its
        meaning is defined by that method, which is not shown here.
        """
        return dataset.map(
            lambda idx: tf.py_function(
                self.read_nifti_file,
                [idx, flag],
                [tf.float32, tf.float32],
            ),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

    def _batch_and_prefetch(self, dataset):
        """Batch a dataset with ``self.batch_size`` and enable prefetching."""
        return dataset.batch(self.batch_size).prefetch(
            tf.data.experimental.AUTOTUNE)

    def get_feature_shape(self):
        """Get the shape of an example feature array.

        The shape is not statically known because elements are produced
        by ``tf.py_function``, so one training batch is read (requires
        eager execution) and the result is cached.

        Returns:
            tuple: shape of a single example feature array, i.e. the
                batch shape without its leading batch dimension.
        """
        if self._feature_shape is None:
            features, _ = next(iter(self.ds_train.take(1)))
            self._feature_shape = tuple(features.shape[1:])
        return self._feature_shape

    def get_train_loader(self, batch_size=None):
        """Get training data loader.

        Args:
            batch_size: Ignored; the batch size is fixed when the
                pipeline is built in ``__init__``.

        Returns:
            The batched, prefetched, repeating training dataset.
        """
        return self.ds_train

    def get_valid_loader(self, batch_size=None):
        """Get validation data loader.

        Args:
            batch_size: Ignored; the batch size is fixed when the
                pipeline is built in ``__init__``.

        Returns:
            The batched, prefetched validation dataset.
        """
        return self.ds_val

    def get_train_data_size(self):
        """Get total number of training samples.

        Returns:
            int: number of training samples
        """
        return self.num_train

    def get_valid_data_size(self):
        """Get total number of validation samples.

        Returns:
            int: number of validation samples
        """
        return self.num_val
Thank you very much for your reply!
Hi there,
I am trying to process some 3D medical images (some .nii.gz files) with openFL but I am having some trouble doing so. My data loader is as follows: (data loader from 3D_unet model)
It outputs some PrefetchObjects: ds_train, ds_val and ds_test. However, according to the data loader file, I believe OpenFL expects data loaders to output X_train, y_train, X_valid and y_valid, and then performs some follow-up operations (e.g., get batch) on them. I personally would find it easier if there were an option to use the PrefetchObjects directly instead of converting them to X_train, y_train, etc.
So I was wondering whether OpenFL could provide a way to support data loaders for .nii.gz files?
Thank you so much for your attention!