Open dryglicki opened 3 months ago
In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is this here? Who is the target audience? Is the expectation of the Enqueuer in the PyDataset class also the reason I need to raise a StopIteration in __getitem__?
You should not ever need to use it. It's internal.
Looking inside the source code, PyDataset has an Adapter class that will make a Tensorflow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental distribute dataset function?
You can call it yourself, but you don't have to. If you don't, the framework will distribute your dataset for you.
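For anyone finding this later, here is a minimal sketch of that path (a toy dataset and toy model, not the HDF pipeline from this issue): build the model under a tf.distribute strategy scope and pass the PyDataset to fit() directly, without calling experimental_distribute_dataset yourself.

```python
import numpy as np
import tensorflow as tf
import keras


class ToyDataset(keras.utils.PyDataset):
    """Sketch: a tiny in-memory PyDataset standing in for the real HDF reader."""

    def __init__(self, n=256, batch_size=32, **kwargs):
        super().__init__(**kwargs)
        self.n, self.batch_size = n, batch_size
        self.x = np.random.rand(n, 8).astype("float32")
        self.y = np.random.rand(n, 1).astype("float32")

    def __len__(self):
        return self.n // self.batch_size

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]


strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

# No manual experimental_distribute_dataset call: during fit(), Keras converts
# the PyDataset to tf.data and distributes it across the strategy's replicas.
model.fit(ToyDataset(), epochs=2)
```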
Thanks @fchollet. I was too quick with the send, and that does appear to be happening. What is throwing me is that in my example, there's a shuffle attribute that gets propagated down to the tf.data call and the shuffle buffer is getting filled without my asking for it to do so explicitly. I think that's a bug.
PyDataset will be shuffled unless it's infinite. The shuffling is in the batch indices. tf.data.Dataset is assumed to be shuffled. See here.
This is the Trainer's fit.shuffle description:

shuffle: Boolean, whether to shuffle the training data before each epoch. This argument is ignored when x is a generator or a tf.data.Dataset.

So it's True for PyDataset (unless it's infinite), False for tf.data.Dataset imo.
I'm also running into issues with subclassing keras.utils.PyDataset, namely that I found that I had to do bounds-checking in __getitem__() to ensure that idx is within the range given by __len__(), and raise an IndexError exception if it exceeds it. I was surprised that, when I pass an instance of my PyDataset to the numpy.array constructor, it calls __len__() but then iterates forever unless I add this bounds check.
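For reference, a minimal sketch of that bounds check (my own toy example, not the class from this issue): raising IndexError for an out-of-range index is what lets consumers such as numpy.array stop cleanly.

```python
import numpy as np
import keras


class BoundedDataset(keras.utils.PyDataset):
    """Sketch: a finite PyDataset with an explicit bounds check in __getitem__."""

    def __init__(self, data, batch_size=10, **kwargs):
        super().__init__(**kwargs)
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        # Floor division: drop any trailing partial batch so every batch is full-size.
        return len(self.data) // self.batch_size

    def __getitem__(self, idx):
        # Without this check, consumers that keep probing past the end
        # (e.g. numpy falling back to the iteration protocol) never stop.
        if idx >= len(self):
            raise IndexError(f"batch index {idx} is out of range")
        start = idx * self.batch_size
        return self.data[start : start + self.batch_size]


ds = BoundedDataset(np.arange(100, dtype="float32"))
print(np.array(ds).shape)  # (10, 10): stops cleanly thanks to the IndexError
```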
Some other comments:
I'm guessing this is to support infinite datasets, but this should be documented in the PyDataset docs. Similarly, there was a change to use the num_batches property in lieu of len in one of the more recent commits, but there was no documentation to indicate we should implement a num_batches property.
Also, in my case, I need to use floor division (len // batch_size) because I want all of my batches to be the same size.
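To spell out the difference with hypothetical numbers (1000 samples, batch size 32):

```python
import math

n, batch_size = 1000, 32          # hypothetical sizes
print(math.ceil(n / batch_size))  # 32 batches, but the last one holds only 8 samples
print(n // batch_size)            # 31 batches, every one exactly 32 samples
```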
Lastly, an example of a sliding-window dataset would be appreciated. I assume that __len__() would be array size - batch size and __getitem__() would return array[idx:idx + batch_size].
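Something along these lines, as a rough sketch (my own illustration, not an official Keras example); note that a stride-1 window count works out to array size minus window size, plus one:

```python
import numpy as np
import keras


class SlidingWindowDataset(keras.utils.PyDataset):
    """Sketch: a stride-1 sliding window over a 1-D array."""

    def __init__(self, array, window_size, **kwargs):
        super().__init__(**kwargs)
        self.array = array
        self.window_size = window_size

    def __len__(self):
        # Number of complete windows: array size minus window size, plus one.
        return len(self.array) - self.window_size + 1

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(f"window index {idx} is out of range")
        return self.array[idx : idx + self.window_size]


ds = SlidingWindowDataset(np.arange(10, dtype="float32"), window_size=4)
print(len(ds))  # 7
print(ds[0])    # [0. 1. 2. 3.]
print(ds[6])    # [6. 7. 8. 9.]
```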
@CDKnightNASA I had to raise StopIteration manually in my code to get it to stop.

I need to test the shuffle. In my code above, I manually shuffle via on_epoch_end. Having the shuffle buffer filled in a tf.data construct was unexpected.
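For context, the on_epoch_end shuffle mentioned above usually looks something like this sketch (hypothetical names, not the code from this thread): the dataset keeps its own index permutation and reshuffles it when Keras calls on_epoch_end.

```python
import numpy as np
import keras


class ManuallyShuffledDataset(keras.utils.PyDataset):
    """Sketch: shuffle the sample order yourself via on_epoch_end."""

    def __init__(self, x, y, batch_size=32, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.indices = np.arange(len(x))

    def __len__(self):
        return len(self.x) // self.batch_size

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(f"batch index {idx} is out of range")
        ids = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        return self.x[ids], self.y[ids]

    def on_epoch_end(self):
        # Keras calls this at the end of each epoch; reshuffle the sample order.
        np.random.shuffle(self.indices)
```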
Keras Version: 3.5.0
Tensorflow Version: 2.17.0
What I want to do: Use the PyDataset class in a distributed data environment.
I would like to ask about the status of PyDataset and some of its best uses and practices. I have a functioning PyDataset class that ingests and processes HDF files:
This works for my case really nicely. It avoids the memory-leak nightmare I have been dealing with while trying to use the tf.data API directly (https://github.com/tensorflow/tensorflow/issues/72014) for multiple inputs from the same file. But the documentation on PyDataset stinks!
Looking inside the source code, PyDataset has an Adapter class that will make a Tensorflow data generator. Does this automatically get called during fit()? Is it best practice to call the data generator directly so I can distribute the dataset via TF's experimental distribute dataset function?

In the source, there is also a PyDatasetEnqueuer class. Do I need this? Why is this here? Who is the target audience? Is the expectation of the Enqueuer in the PyDataset class also the reason I need to raise a StopIteration in __getitem__?

Also digging into the source, at this point the shuffle buffer size is hard-coded to 8. That probably needs to go.
Anyway, I don't have any specific programming questions here, but I would like to know what the best practices are, how to use PyDataset in a (Tensorflow) distributed data environment, and so on.