keras-team / keras

Deep Learning for humans
http://keras.io/

PyDataset Documentation and Best Practices #20142

Open dryglicki opened 3 weeks ago

dryglicki commented 3 weeks ago

Keras version: 3.5.0, TensorFlow version: 2.17.0

What I want to do: use the PyDataset class in a distributed-data environment.


I would like to ask about the status of PyDataset and its best practices. I have a functioning PyDataset subclass that ingests and processes HDF5 files:

import random

import h5py
import numpy as np
import keras as K

class HDFDataset(K.utils.PyDataset):
    '''
    Keras data loader to replace Tensorflow's Dataset API.
    Reads HDF5 files.
    Inputs:
        file_list: list
            list of file names, pre-globbed
        batch_size: int
            size of batches
        shuffle: bool
            whether or not to shuffle the dataset at the end of each epoch
        lons_lats: bool
            whether or not to include longitudes and latitudes

        -- Additional keyword arguments --
        workers=1
        use_multiprocessing=False
        max_queue_size=10
    '''
    def __init__(self,
            file_list: list | tuple | set,
            batch_size: int,
            shuffle: bool = False,
            lons_lats: bool = False,
            subsample: bool = False,
            **kwargs):
        super().__init__(**kwargs)
        self.file_list = list(file_list)  # explicit copy so we can shuffle it in place
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.lons_lats = lons_lats
        self.tmplen = len(self.file_list)

        self.subsample = subsample
        if self.subsample:
            self.slice = slice(64, 192)
            self.time_slice = slice(0,6)

    def __len__(self):
        return self.tmplen // self.batch_size

    def _extract_data_from_hdf5(self, file_list):
        input_list = ['priors', 'model']
        output_list = ['forecast']

        # Preparing input dictionary
        inputs_dict = {}
        for name in input_list:
            new_var = f'input_{name}'
            inputs_dict[new_var] = []

        outputs = []

        for f in file_list:
            with h5py.File(f, 'r') as h5:
                for k in input_list:
                    new_var = f'input_{k}'
                    if self.subsample:
                        inputs_dict[new_var].append(h5.get(k)[:, self.slice, self.slice, :])
                    else:
                        inputs_dict[new_var].append(h5.get(k)[...])
                for k in output_list:
                    if self.subsample:
                        outputs.append(h5.get(k)[self.time_slice, self.slice, self.slice, :])
                    else:
                        outputs.append(h5.get(k)[...])

        for k in input_list:
            nv = f'input_{k}'
            inputs_dict[nv] = np.stack(inputs_dict[nv], axis = 0)

        outputs = np.stack(outputs, axis = 0)

        return inputs_dict, outputs

    def __getitem__(self,
            idx: int):

        if idx >= self.__len__(): raise StopIteration

        low = idx * self.batch_size

        high = min(low + self.batch_size, self.tmplen)

        inputs, outputs = self._extract_data_from_hdf5(self.file_list[low:high])

        return inputs, outputs  # (x, y) tuple expected by Keras

    def on_epoch_end(self):
        if self.shuffle: random.shuffle(self.file_list) # In-place shuffle
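
For reference, this is roughly how I instantiate it (a sketch with hypothetical file paths; the workers / use_multiprocessing / max_queue_size keyword arguments are forwarded to PyDataset by super().__init__()):

import glob

files = sorted(glob.glob('/path/to/hdf5/*.h5'))   # hypothetical data location
train_ds = HDFDataset(
    file_list=files,
    batch_size=8,
    shuffle=True,
    subsample=True,
    workers=4,                   # PyDataset keyword arguments
    use_multiprocessing=True,
    max_queue_size=16,
)
# model.fit(train_ds, epochs=...)  # with a compiled model whose inputs are
#                                  # named 'input_priors' and 'input_model'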

This works really nicely for my case. It avoids the memory-leak nightmare I have been dealing with when using the tf.data API directly (https://github.com/tensorflow/tensorflow/issues/72014) for multiple inputs from the same file.

But the documentation on PyDataset stinks!

Looking inside the source code, PyDataset has an Adapter class that will make a TensorFlow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental_distribute_dataset function?

In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is it here? Who is the target audience? Is the Enqueuer's expectation in the PyDataset class also the reason I need to raise StopIteration in __getitem__?

Also, digging into the source, the shuffle buffer size is currently hard-coded to 8. That probably needs to go.

Anyway, I don't have any specific programming questions here, but I would like to know what the best practices are, how to use PyDataset in a (TensorFlow) distributed-data environment, and so on.
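
For concreteness, the manual route I have in mind would look something like this (just a sketch, not the internal adapter; the shapes and dtypes in output_signature are placeholders, and train_ds is an HDFDataset instance like the one above):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def batches():
    # Yield one (inputs_dict, outputs) batch per index of the PyDataset.
    for i in range(len(train_ds)):
        yield train_ds[i]

# Placeholder specs -- the real shapes and dtypes depend on the HDF5 contents.
spec = tf.TensorSpec(shape=(None, None, None, None, None), dtype=tf.float32)
signature = ({'input_priors': spec, 'input_model': spec}, spec)

tf_ds = tf.data.Dataset.from_generator(batches, output_signature=signature)
dist_ds = strategy.experimental_distribute_dataset(tf_ds)
# dist_ds would then be consumed by a custom training loop via strategy.run(...)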

fchollet commented 2 weeks ago

In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is it here? Who is the target audience? Is the Enqueuer's expectation in the PyDataset class also the reason I need to raise StopIteration in __getitem__?

You should not ever need to use it. It's internal.

Looking inside the source code, PyDataset has an Adapter class that will make a TensorFlow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental_distribute_dataset function?

You can call it yourself, but you don't have to. If you don't, the framework will distribute your dataset for you.
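
Concretely, on the TF backend the simple path would look something like this (a sketch; build_model() stands in for whatever model you build and compile, and train_ds for your PyDataset instance):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = build_model()                        # hypothetical model factory
    model.compile(optimizer='adam', loss='mse')

# No manual tf.data wrapping: hand the PyDataset straight to fit()
# and Keras converts and distributes it internally.
model.fit(train_ds, epochs=10)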

dryglicki commented 2 weeks ago

Thanks @fchollet. I was too quick to hit send, and that does appear to be happening. What is throwing me is that, in my example, the shuffle attribute gets propagated down to the tf.data call and the shuffle buffer gets filled without my asking for it explicitly. I think that's a bug.

ghsanti commented 2 days ago

the shuffle attribute gets propagated down to the tf.data call and the shuffle buffer gets filled without my asking for it explicitly. I think that's a bug.

This is the description of the shuffle argument in the Trainer's fit():

shuffle: Boolean, whether to shuffle the training data before each epoch. This argument is ignored when x is a generator or a tf.data.Dataset.


So shuffling is True for a PyDataset (unless it is infinite), and False for a tf.data.Dataset, imo.
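
So if the shuffle buffer is what you want to avoid, passing shuffle=False to fit should skip it for a finite PyDataset (my reading of the adapter code path, not verified):

# Opt out of shuffling for the PyDataset path (train_ds / model as defined earlier)
model.fit(train_ds, epochs=10, shuffle=False)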

@dryglicki