keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

PyDataset Documentation and Best Practices #20142

Open dryglicki opened 3 months ago

dryglicki commented 3 months ago

Keras version: 3.5.0
TensorFlow version: 2.17.0

What I want to do: use the PyDataset class in a distributed-data training environment.

I would like to ask about the status of PyDataset and its best practices. I have a functioning PyDataset subclass that ingests and processes HDF5 files:

import random
import h5py
import numpy as np
import keras as K

class HDFDataset(K.utils.PyDataset):
    '''
    Keras data loader to replace Tensorflow's Dataset API.
    Reads HDF5 files.
    Inputs:
        file_list: list
            list of file names, pre-globbed
        batch_size: int
            size of batches
        shuffle: bool
            whether or not to shuffle the dataset at the end of each epoch
        lons_lats: bool
            whether or not to include longitudes and latitudes

        -- Additional keyword arguments --
        workers=1
        use_multiprocessing=False
        max_queue_size=10
    '''
    def __init__(self,
            file_list: list | tuple | set,
            batch_size: int,
            shuffle: bool = False,
            lons_lats: bool = False,
            subsample: bool = False,
            **kwargs):
        super().__init__(**kwargs)
        self.file_list = list(file_list)  # mutable copy, so the in-place shuffle is safe
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.tmplen = len(self.file_list)

        self.subsample = subsample
        if self.subsample:
            self.slice = slice(64, 192)
            self.time_slice = slice(0,6)

    def __len__(self):
        return self.tmplen // self.batch_size

    def _extract_data_from_hdf5(self, file_list):
        input_list = ['priors', 'model']
        output_list = ['forecast']

        # Preparing input dictionary
        inputs_dict = {}
        for name in input_list:
            new_var = f'input_{name}'
            inputs_dict[new_var] = []

        outputs = []

        for f in file_list:
            with h5py.File(f, 'r') as h5:
                for k in input_list:
                    new_var = f'input_{k}'
                    if self.subsample:
                        inputs_dict[new_var].append(h5.get(k)[:, self.slice, self.slice, :])
                    else:
                        inputs_dict[new_var].append(h5.get(k)[...])
                for k in output_list:
                    if self.subsample:
                        outputs.append(h5.get(k)[self.time_slice, self.slice, self.slice, :])
                    else:
                        outputs.append(h5.get(k)[...])

        for k in input_list:
            nv = f'input_{k}'
            inputs_dict[nv] = np.stack(inputs_dict[nv], axis = 0)

        outputs = np.stack(outputs, axis = 0)

        return inputs_dict, outputs

    def __getitem__(self,
            idx: int):

        if idx >= self.__len__(): raise StopIteration

        low = idx * self.batch_size

        high = min(low + self.batch_size, self.tmplen)

        inputs, outputs = self._extract_data_from_hdf5(self.file_list[low:high])

        return inputs, outputs  # (x, y) pair

    def on_epoch_end(self):
        if self.shuffle: random.shuffle(self.file_list) # In-place shuffle

This works really nicely for my case. It avoids the memory-leak nightmare I had been fighting when using the tf.data API directly (https://github.com/tensorflow/tensorflow/issues/72014) for multiple inputs read from the same file.

But the documentation on PyDataset stinks!

Looking inside the source code, PyDataset has an Adapter class that will make a Tensorflow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental distribute dataset function?

In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is it here? Who is the target audience? Is the Enqueuer's expectation in the PyDataset class also the reason I need to raise StopIteration in __getitem__?

Also, digging into the source, the shuffle buffer size is currently hard-coded to 8. That probably needs to go.

Anyway, I don't have any specific programming questions here, but I would like to know what the best practices are, how to use PyDataset in a (TensorFlow) distributed-data environment, and so on.

fchollet commented 3 months ago

In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is this here? Who is the target audience? Is the expectation of the Enqueuer in the PyDataset class also the reason I need to raise StopIteration in __getitem__?

You should not ever need to use it. It's internal.

Looking inside the source code, PyDataset has an Adapter class that will make a Tensorflow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental distribute dataset function?

You can call it yourself, but you don't have to. If you don't, the framework will distribute your dataset for you.
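To make the "you don't have to call it yourself" point concrete, here is a hedged, pure-Python sketch of the adapter idea: an index-based dataset (__len__ plus __getitem__, as in PyDataset) can be wrapped as a plain generator, which a training loop (or something like tf.data.Dataset.from_generator) can then consume. The class and function names are made up for illustration; this is not Keras's actual adapter code.

```python
class ToyDataset:
    """Stand-in for a PyDataset subclass: batches addressed by integer index."""
    def __init__(self, n_batches):
        self.n_batches = n_batches

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(idx)
        return {"x": [idx, idx], "y": [idx]}  # fake (inputs, targets) batch

def batches(ds):
    # Roughly what an adapter does: one pass over the dataset in index order,
    # exposed as a generator.
    for i in range(len(ds)):
        yield ds[i]

for batch in batches(ToyDataset(3)):
    print(batch["y"])  # [0] then [1] then [2]
```

When you hand the PyDataset itself to fit(), the framework performs this adaptation (and the distribution) internally, per the answer above.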

dryglicki commented 3 months ago

Thanks @fchollet. I was too quick with the send, and that does appear to be happening. What is throwing me is that in my example, there's a shuffle attribute that gets propagated down to the tf.data call and the shuffle buffer is getting filled without my asking for it to do so explicitly. I think that's a bug.

ghsanti commented 2 months ago

a shuffle attribute that gets propagated down to the tf.data call and the shuffle buffer is getting filled without my asking for it to do so explicitly. I think that's a bug.

This is the description of the shuffle argument to the Trainer's fit():

shuffle: Boolean, whether to shuffle the training data before each epoch. This argument is ignored when x is a generator or a tf.data.Dataset.


So it's True for a PyDataset (unless it's infinite), and ignored for a tf.data.Dataset, imo.

CDKnightNASA commented 2 months ago

I'm also running into issues with subclassing keras.utils.PyDataset, namely that I found I had to do bounds-checking in __getitem__() to ensure that idx is less than the value of __len__(), raising an IndexError exception when it isn't. I was surprised that, when I pass an instance of my PyDataset to the numpy.array constructor, it calls __len__() but then iterates forever unless I add this bounds check.
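The behavior above comes from Python's legacy sequence iteration protocol, which ignores __len__ and keeps calling __getitem__ with 0, 1, 2, ... until an IndexError is raised. A minimal pure-Python illustration (the class name is made up; no Keras involved):

```python
class Batches:
    """Minimal index-based container in the PyDataset style."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Without this check, iter() or numpy.array() would call
        # __getitem__ forever: the legacy iteration protocol stops
        # only when __getitem__ raises IndexError, never via __len__.
        if idx >= self.n:
            raise IndexError(idx)
        return idx * 10

print(list(Batches(3)))  # [0, 10, 20]
```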

Some other comments: I'm guessing this is to support infinite datasets, but it should be documented in the PyDataset docs. Similarly, one of the more recent commits switched to using the num_batches property in lieu of __len__, but there was no documentation indicating that we should implement a num_batches property.

Also, in my case, I need to use "floor" (len // batch_size) as I want all of my batches to be of the same size.
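A quick arithmetic illustration of why the flooring matters: with floor division the ragged tail is dropped, while ceiling division would keep a smaller final batch.

```python
n_samples, batch_size = 10, 3

# Floor division: count only full batches; the 1 leftover sample is dropped.
print(n_samples // batch_size)        # 3

# Ceiling division: keep a smaller final batch of the leftover sample.
print(-(-n_samples // batch_size))    # 4
```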

Lastly, an example of a sliding-window dataset would be appreciated. I assume that __len__() would be array size - batch size and __getitem__() would return array[idx:idx + batch_size].
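A hedged sketch of that sliding-window idea in the same index-based style (plain NumPy, class name made up, not a Keras API). One off-by-one detail worth noting: an array of length n with window w has n - w + 1 full windows, so plain n - w would silently drop the last one.

```python
import numpy as np

class SlidingWindows:
    """One window of fixed length per integer index over a 1-D array."""
    def __init__(self, data, window):
        self.data = np.asarray(data)
        self.window = window

    def __len__(self):
        # n - w + 1 full windows fit in an array of length n.
        return len(self.data) - self.window + 1

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(idx)  # required so iteration terminates
        return self.data[idx:idx + self.window]

ws = SlidingWindows([1, 2, 3, 4, 5], window=3)
print(len(ws))   # 3
print(ws[0])     # [1 2 3]
print(ws[2])     # [3 4 5]
```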

dryglicki commented 2 months ago

@CDKnightNASA I had to issue StopIteration manually in my code to get it to stop.

I need to test the shuffle. In my code above, I shuffle manually via on_epoch_end; having a shuffle buffer filled inside a tf.data construct was unexpected.