dennisobrien opened this issue 2 years ago
@dennisobrien, thanks for creating the issue!
Have a few questions:
- Do we need to support tf.data.Dataset as a Flyte type if the TFRecordFile and TFRecordDirectory Flyte types can handle that beneath? (See the https://github.com/flyteorg/flytekit/pull/1240 PR for reference.)
- Will tf.data.experimental.TFRecordWriter work for any tf.data.Dataset irrespective of calling the methods?
- Will tf.data.Dataset.save not work in case the methods have already been called on the tf.data.Dataset, or does it just slow down the tf.data.Dataset.load operation?

cc: @cosmicBboy
Curious about the answers to @samhita-alla's questions above.
My additional question: do people typically subclass tf.data.Dataset to define their own dataset classes, or is tf.data.Dataset the primary class they use?
It would also help, @dennisobrien, if you could come up with some pseudocode snippets for how tf.data.Dataset would be used at the interface of two Flyte tasks. This would clarify this caveat:
Since a tf.data.Dataset object can have steps in the pipelines that use local Python functions (e.g., a map or filter step), there doesn't seem to be a way to serialize the object without effectively "computing" the graph pipeline
The main thing I'm interested in is how tf.data.Dataset would work in the following type of scenario:
```python
@task
def t1(...) -> tf.data.Dataset:
    dataset = tf.data.Dataset.range(100)
    return (
        dataset
        ...  # a bunch of transformations, potentially with local functions
    )

@task
def t2(dataset: tf.data.Dataset):
    ...
```
Suppose the Flyte TypeTransformer takes the output of t1 (after the bunch of transformations) and automatically calls dataset.save ... will this actually work? I.e., will the dataset be serialized with all the transformations applied to it, or will the dataset transforms be somehow serialized with it?
On the flipside, when t2 automatically deserializes the result of that with tf.data.Dataset.load, will materializing the data actually work? Or would the user need to apply the same transformations to the raw data again?
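To make the mechanics of that scenario concrete, here is a rough, untested sketch that simulates what such a transformer would do at the task boundary, using plain TensorFlow rather than any actual flytekit machinery. The /tmp path is a placeholder, and it assumes TensorFlow 2.10+, where Dataset.save/load are no longer experimental:

```python
import tensorflow as tf

# "t1": build a pipeline with transformations, including local Python lambdas.
dataset = (
    tf.data.Dataset.range(100)
    .filter(lambda x: x % 2 == 0)
    .map(lambda x: x * 10)
)

# What the hypothetical transformer would do with t1's output:
dataset.save("/tmp/flyte_dataset")  # placeholder path; this runs the whole pipeline

# What it would do before handing the value to "t2":
restored = tf.data.Dataset.load("/tmp/flyte_dataset")

# "t2" then receives a dataset that yields the already-transformed elements.
print(list(restored.take(5).as_numpy_iterator()))  # [0, 20, 40, 60, 80]
```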
Sorry for the delay in responding.
- Do we need to support tf.data.Dataset as a Flyte type if TFRecordFile and TFRecordDirectory Flyte types can handle that beneath?
The tf.data.Dataset is very different from a TFRecord. A TFRecord is really a binary format for data, while a tf.data.Dataset is a pipeline that can define transformations, filtering, repeating, batching, shuffling, and more. That said, once you serialize and then deserialize your tf.data.Dataset, it will be very similar to a TFRecord in that it will just be the end result of that pipeline. At least, I haven't found a way to serialize/deserialize in a way that doesn't result in a "compute" of the tf.data.Dataset. But I'm really not an expert with this, so I would not be surprised if I have this wrong.
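One way to see that pipeline-versus-data distinction is that building a tf.data.Dataset runs nothing; the work only happens when something iterates it, and save is one such consumer. A small sketch, assuming TensorFlow 2.10+ and using a placeholder path:

```python
import tensorflow as tf

calls = []

def log_and_square(x):
    # Python-side side effect so we can observe when the pipeline actually runs.
    calls.append(int(x))
    return x * x

ds = tf.data.Dataset.range(5).map(
    lambda x: tf.py_function(log_and_square, [x], tf.int64)
)

print(len(calls))        # 0: defining the pipeline computes nothing
ds.save("/tmp/ds_demo")  # serializing iterates the pipeline, i.e. the "compute"
print(len(calls))        # 5: every element was produced in order to write it out
```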
- Will tf.data.experimental.TFRecordWriter work for any tf.data.Dataset irrespective of calling the methods?
If I understand this correctly, the flow would go:
tf.data.Dataset → tf.data.experimental.TFRecordWriter → tf.data.TFRecordDataset → tf.data.Dataset
I think this would work, but it would have the same effect as using tf.data.Dataset.save and tf.data.Dataset.load (or tf.data.experimental.save and tf.data.experimental.load pre TensorFlow 2.10) in that you would cause a "compute" on the pipeline graph.
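For what it's worth, here is roughly what that flow looks like in code. It's an untested sketch with a placeholder path; note that TFRecordWriter only accepts a dataset of scalar strings, so the elements have to be serialized to bytes first and decoded by hand on the way back:

```python
import tensorflow as tf

# Start from an arbitrary pipeline and turn each element into a bytes string.
ds = tf.data.Dataset.range(10).map(lambda x: x * 2)
string_ds = ds.map(lambda x: tf.io.serialize_tensor(x))

path = "/tmp/demo.tfrecord"  # placeholder local path
writer = tf.data.experimental.TFRecordWriter(path)
writer.write(string_ds)      # this iterates the pipeline (the "compute")

# Reading back: we get raw bytes and must know how to decode them ourselves.
restored = tf.data.TFRecordDataset(path).map(
    lambda s: tf.io.parse_tensor(s, out_type=tf.int64)
)
print(list(restored.as_numpy_iterator()))  # [0, 2, 4, ..., 18]
```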
Will tf.data.Dataset.save not work in case the methods have already been called on tf.data.Dataset or does it just slow down the tf.data.Dataset.load operation?
I think using tf.data.Dataset.save and tf.data.Dataset.load would work in that the object would be serialized/deserialized, and in most cases passing this to model.fit would work as expected. But there are some cases where there might be surprises. Say, for example, that your tf.data.Dataset pipeline started with a list of image files in S3, decoded the images and loaded them into memory, shuffled, batched, then did some image augmentation (on that shuffled batch). When this dataset is passed to model.fit, each epoch of the training would see differently augmented data. In the case of serializing and deserializing the dataset, each epoch would see the same data.
I don't know if there is a way around this -- I haven't had the need to serialize/deserialize a dataset before using it, so I've really only researched it while thinking about using it with Flyte.
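A tiny sketch of that surprise, using a random map in place of real image augmentation (TensorFlow 2.10+, placeholder path):

```python
import tensorflow as tf

# A pipeline with a random "augmentation" step.
ds = tf.data.Dataset.range(4).map(
    lambda x: tf.cast(x, tf.float32) + tf.random.uniform([])
)

# Live pipeline: each pass re-runs the random map, so two "epochs" differ.
epoch_1 = list(ds.as_numpy_iterator())
epoch_2 = list(ds.as_numpy_iterator())
# epoch_1 != epoch_2 (with overwhelming probability)

# After save/load the random values are baked in: every epoch is identical.
ds.save("/tmp/aug_demo")
restored = tf.data.Dataset.load("/tmp/aug_demo")
epoch_a = list(restored.as_numpy_iterator())
epoch_b = list(restored.as_numpy_iterator())
# epoch_a == epoch_b
```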
Serializing/deserializing a TFRecord is a lot more straightforward because there is no state besides the data.
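For comparison, a TFRecord file really is just a sequence of byte strings with no pipeline state attached; a minimal illustration (placeholder path):

```python
import tensorflow as tf

# Write three raw byte-string records to a TFRecord file.
with tf.io.TFRecordWriter("/tmp/plain.tfrecord") as w:
    for payload in [b"first", b"second", b"third"]:
        w.write(payload)

# Reading back yields exactly those bytes, nothing more.
for record in tf.data.TFRecordDataset("/tmp/plain.tfrecord"):
    print(record.numpy())  # b'first', b'second', b'third'
```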
do people typically subclass tf.data.Dataset to define their own dataset classes, or is tf.data.Dataset the primary class they use?
I have only used tf.data.Dataset, but I know there are several included subclasses such as TextLineDataset and CsvDataset. I would expect that each of these classes would support save, load, and snapshot. I'm not sure if subclassing tf.data.Dataset is a common technique.
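For example, TextLineDataset is a Dataset subclass, so the same save/load machinery applies to it. A small sketch with a throwaway local file (note that load returns a plain Dataset, not a TextLineDataset):

```python
import tensorflow as tf

# "lines.txt" here is a throwaway file created just for the example.
with open("/tmp/lines.txt", "w") as f:
    f.write("alpha\nbeta\ngamma\n")

lines = tf.data.TextLineDataset("/tmp/lines.txt")
lines.save("/tmp/lines_dataset")            # same materializing save as any Dataset
restored = tf.data.Dataset.load("/tmp/lines_dataset")
print(list(restored.as_numpy_iterator()))   # [b'alpha', b'beta', b'gamma']
```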
In the case of serializing and deserializing the dataset, each epoch would see the same data.
Um, gotcha. I think that's expected because the compute is serialized, too. And I think that's okay as long as we let the user know about it and the compute is supported by serialization.
The tf.data.Dataset is very different from a TFRecord.
I agree. I was referring to handling tf.data.Dataset with a TFRecordFile Flyte type:
```python
@task
def produce_record(...) -> TFRecordFile:
    return tf.data.Dataset(...)
```
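A slightly more concrete, untested sketch of that idea might write the TFRecords explicitly inside the task and return the file. The TFRecordFile import path, the tempfile-based local path, and the assumption that the returned file gets uploaded like a regular FlyteFile are all assumptions here, based on the PR referenced above rather than confirmed behavior:

```python
import tempfile

import tensorflow as tf
from flytekit import task
from flytekit.types.file import TFRecordFile  # assumed import path


@task
def produce_record() -> TFRecordFile:
    # Build the pipeline, then explicitly materialize it as TFRecords;
    # the "compute" happens here, inside the task.
    ds = tf.data.Dataset.range(100).map(lambda x: tf.io.serialize_tensor(x))
    path = tempfile.mkdtemp() + "/data.tfrecord"
    tf.data.experimental.TFRecordWriter(path).write(ds)
    # Assuming TFRecordFile behaves like a FlyteFile, the local file would be
    # uploaded to blob storage when the task returns.
    return TFRecordFile(path)
```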
In the documentation for Distributed Tensorflow Training, some tasks use tf.data.Dataset as input and/or output. Does that mean that this now works? I see the class TensorFlowRecordFileTransformer(TypeTransformer[TFRecordFile]), but found nothing for tf.data.Dataset.
@nightscape the type isn't available yet, so the data will be pickled.
@nightscape, contributions are welcome. Here's some documentation on how to extend Flyte's types to cover custom types.
Motivation: Why do you think this is important?
The tf.data.Dataset object encapsulates data as well as a preprocessing pipeline. It can be used in the model fit, predict, and evaluate methods. It is widely used in TensorFlow tutorials and documentation and is considered a best practice when creating pipelines that saturate GPU resources.
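For context, this is the kind of usage the motivation refers to; a minimal, self-contained sketch with synthetic data:

```python
import tensorflow as tf

# A tf.data.Dataset of (features, labels) batches fed straight into Keras.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 4]),
     tf.random.uniform([256], maxval=2, dtype=tf.int32))
).shuffle(256).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

model.fit(ds, epochs=2)   # the dataset drives shuffling, batching, prefetching
model.evaluate(ds)
```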
Goal: What should the final outcome look like, ideally?
Flyte tasks should be able to pass tf.data.Dataset objects as parameters and accept them as return types.
Describe alternatives you've considered
There are caveats to passing tf.data.Dataset objects between tasks. Since a tf.data.Dataset object can have steps in the pipelines that use local Python functions (e.g., a map or filter step), there doesn't seem to be a way to serialize the object without effectively "computing" the graph pipeline. There are times this could be beneficial (doing an expensive preprocessing pipeline once can free up the CPU during training), but this could also be confusing to the Flyte end user.
So while adding a type transformer for tf.data.Dataset is certainly possible, it's still a good question if Flyte should actually support it given all the caveats. The alternative to consider here is to not support tf.data.Dataset. This seems like a question for the core Flyte team.
Propose: Link/Inline OR Additional context
There are at least three main ways to serialize/deserialize tf.data.Dataset objects. These are probably in order of least complex to most complex. But determining the method of serialization/deserialization is an open question.
Some additional links:
- Copying a tf.data.Dataset as a deep copy without having the side-effect of "computing" the pipeline.
Are you sure this issue hasn't been raised already?
Have you read the Code of Conduct?