flyteorg / flyte

Scalable and flexible workflow orchestration platform that seamlessly unifies data, ML and analytics stacks.
https://flyte.org
Apache License 2.0

[Plugin] TypeTransformer for TensorFlow tf.data.Dataset #3038

Open dennisobrien opened 2 years ago

dennisobrien commented 2 years ago

Motivation: Why do you think this is important?

The tf.data.Dataset object encapsulates both data and a preprocessing pipeline. It can be passed directly to a model's fit, predict, and evaluate methods. It is used widely in TensorFlow tutorials and documentation and is considered a best practice for building input pipelines that keep GPU resources saturated.

Goal: What should the final outcome look like, ideally?

Flyte tasks should be able to pass tf.data.Dataset objects as parameters and accept them as return types.

Describe alternatives you've considered

There are caveats to passing tf.data.Dataset objects between tasks. Since a tf.data.Dataset object can have pipeline steps that call local Python functions (e.g., a map or filter step), there doesn't seem to be a way to serialize the object without effectively "computing" the pipeline graph. At times this could even be beneficial (running an expensive preprocessing pipeline once frees up the CPU during training), but it could also confuse the Flyte end user.

So while adding a type transformer for tf.data.Dataset is certainly possible, it's still an open question whether Flyte should actually support it given all the caveats. The alternative to consider here is simply not supporting tf.data.Dataset. This seems like a question for the core Flyte team.

Propose: Link/Inline OR Additional context

There are at least three main ways to serialize/deserialize tf.data.Dataset objects.

  1. tf.data.Dataset.save and tf.data.Dataset.load
  2. tf.data.Dataset.snapshot
  3. Iterator checkpointing

These are listed roughly in order from least to most complex, but which serialization/deserialization method to use is an open question.
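As a minimal sketch of option 1 (assuming TensorFlow >= 2.10; earlier versions expose the same functionality as tf.data.experimental.save/load; the path is hypothetical):

import tensorflow as tf

# Build a pipeline that includes a Python-level transformation.
dataset = tf.data.Dataset.range(10).map(lambda x: x * 2)

# Saving iterates the pipeline and writes out the resulting elements,
# so the map step is effectively "computed" at save time.
dataset.save("/tmp/my_dataset")  # hypothetical path

# The restored dataset replays the saved elements; the original
# transformation graph is not re-executed.
restored = tf.data.Dataset.load("/tmp/my_dataset")
print(list(restored.as_numpy_iterator()))  # [0, 2, 4, ..., 18]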


samhita-alla commented 2 years ago

@dennisobrien, thanks for creating the issue!

Have a few questions:

  1. Do we need to support tf.data.Dataset as a Flyte type if the TFRecordFile and TFRecordDirectory Flyte types can handle it under the hood? (see https://github.com/flyteorg/flytekit/pull/1240 for reference)
  2. Will tf.data.experimental.TFRecordWriter work for any tf.data.Dataset, irrespective of which methods have been called on it?
  3. Will tf.data.Dataset.save not work if methods have already been called on the tf.data.Dataset, or does it just slow down the tf.data.Dataset.load operation?

cc: @cosmicBboy

cosmicBboy commented 2 years ago

Curious about the answers to @samhita-alla's questions above.

My additional question:

It would also help, @dennisobrien, if you could come up with some pseudocode snippets showing how tf.data.Dataset would be used at the interface of two Flyte tasks. This would clarify this caveat:

Since a tf.data.Dataset object can have steps in the pipelines that use local Python functions (e.g., a map or filter step), there doesn't seem to be a way to serialize the object without effectively "computing" the graph pipeline

The main thing I'm interested in is how tf.data.Dataset would work in the following type of scenario:

@task
def t1(...) -> tf.data.Dataset:
    dataset = tf.data.Dataset.range(100)
    return (
        dataset
        ... # a bunch of transformations, potentially with local functions
    )

@task
def t2(dataset: tf.data.Dataset):
    ...

Suppose the Flyte TypeTransformer takes the output of t1 (after the bunch of transformations) and automatically calls dataset.save... will this actually work? I.e., will the dataset be serialized with all the transformations already applied, or will the dataset transforms somehow be serialized along with it?

On the flip side, when t2 automatically deserializes that result with tf.data.Dataset.load, will materializing the data actually work? Or would the user need to apply the same transformations to the raw data again?
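For concreteness, a save/load-based transformer might look roughly like the sketch below. This is hypothetical, not working plugin code: the TensorFlowDatasetTransformer name and "TFDataset" format tag are invented, and the exact TypeTransformer and file-access method signatures vary across flytekit versions.

from typing import Type

import tensorflow as tf
from flytekit import FlyteContext
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class TensorFlowDatasetTransformer(TypeTransformer[tf.data.Dataset]):
    # Hypothetical transformer that materializes a tf.data.Dataset via save/load.

    def __init__(self):
        super().__init__(name="TensorFlow Dataset", t=tf.data.Dataset)

    def get_literal_type(self, t: Type[tf.data.Dataset]) -> LiteralType:
        # Store the saved dataset as a multipart blob (a directory of files).
        return LiteralType(
            blob=BlobType(
                format="TFDataset",  # invented format tag
                dimensionality=BlobType.BlobDimensionality.MULTIPART,
            )
        )

    def to_literal(self, ctx: FlyteContext, python_val, python_type, expected) -> Literal:
        local_dir = ctx.file_access.get_random_local_directory()
        python_val.save(local_dir)  # "computes" the pipeline and writes its elements
        remote_dir = ctx.file_access.get_random_remote_directory()
        ctx.file_access.put_data(local_dir, remote_dir, is_multipart=True)
        meta = BlobMetadata(type=self.get_literal_type(python_type).blob)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_dir)))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type):
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(lv.scalar.blob.uri, local_dir, is_multipart=True)
        return tf.data.Dataset.load(local_dir)


TypeEngine.register(TensorFlowDatasetTransformer())

With something along these lines registered, t1's return value would be materialized at serialization time, so t2 would receive the already-computed elements rather than the original transformation graph.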

dennisobrien commented 2 years ago

Sorry for the delay in responding.

  1. Do we need to support tf.data.Dataset as a Flyte type if the TFRecordFile and TFRecordDirectory Flyte types can handle it under the hood?

A tf.data.Dataset is very different from a TFRecord. A TFRecord is really just a binary format for data, while a tf.data.Dataset is a pipeline that can define transformations, filtering, repeating, batching, shuffling, and more. That said, once you serialize and then deserialize your tf.data.Dataset, it will be very similar to a TFRecord in that it is just the end result of that pipeline. At least, I haven't found a way to serialize/deserialize in a way that doesn't result in a "compute" of the tf.data.Dataset. But I'm really not an expert on this, so I would not be surprised if I have this wrong.

  2. Will tf.data.experimental.TFRecordWriter work for any tf.data.Dataset, irrespective of which methods have been called on it?

If I understand this correctly, the flow would go something like: serialize the tf.data.Dataset with tf.data.experimental.TFRecordWriter, then read it back with tf.data.TFRecordDataset.

I think this would work, but it would have the same effect as using tf.data.Dataset.save and tf.data.Dataset.load (or tf.data.experimental.save and tf.data.experimental.load prior to TensorFlow 2.10) in that you would cause a "compute" of the pipeline graph.
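A rough sketch of that flow (the element-to-tf.train.Example encoding below is invented for illustration and is application-specific, as is the file path):

import tensorflow as tf

# Encode each element as a serialized tf.train.Example string.
def to_example(x):
    feature = {"value": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(x)]))}
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

dataset = tf.data.Dataset.range(5)

# TFRecordWriter expects a dataset of scalar strings, so the pipeline
# is iterated (i.e., "computed") while writing.
serialized = dataset.map(lambda x: tf.py_function(to_example, [x], tf.string))
tf.data.experimental.TFRecordWriter("/tmp/data.tfrecord").write(serialized)

# Reading back yields raw serialized records that must be parsed again.
restored = tf.data.TFRecordDataset("/tmp/data.tfrecord")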

  3. Will tf.data.Dataset.save not work if methods have already been called on the tf.data.Dataset, or does it just slow down the tf.data.Dataset.load operation?

I think using tf.data.Dataset.save and tf.data.Dataset.load would work, in that the object would be serialized/deserialized, and in most cases passing it to model.fit would work as expected. But there are some cases where there might be surprises. Say, for example, that your tf.data.Dataset pipeline starts with a list of image files in S3, decodes the images into memory, shuffles, batches, and then applies some image augmentation (on the shuffled batches). When this dataset is passed to model.fit, each epoch of training would see differently augmented data. With a serialized and then deserialized dataset, each epoch would see the same data.

I don't know if there is a way around this -- I haven't had the need to serialize/deserialize a dataset before using it, so I've really only researched it while thinking about using it with Flyte.
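A small sketch of that surprise (hypothetical path; save/load as in TensorFlow >= 2.10):

import tensorflow as tf

# A pipeline with random "augmentation": each pass normally re-draws the noise.
ds = tf.data.Dataset.range(4).map(
    lambda x: tf.cast(x, tf.float32) + tf.random.uniform([])
)
print(list(ds.as_numpy_iterator()))  # differs on each iteration ("epoch")
print(list(ds.as_numpy_iterator()))

# After save/load the random values are baked in: every epoch
# replays exactly the same elements.
ds.save("/tmp/augmented_ds")
frozen = tf.data.Dataset.load("/tmp/augmented_ds")
print(list(frozen.as_numpy_iterator()))  # identical on every iteration
print(list(frozen.as_numpy_iterator()))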

Serializing/Deserializing a TFRecord is a lot more straightforward because there is no state besides the data.

Do people typically subclass tf.data.Dataset to define their own dataset classes, or is tf.data.Dataset the primary class they use?

I have only used tf.data.Dataset directly, but I know there are several built-in subclasses such as TextLineDataset and CsvDataset. I would expect each of these to support save, load, and snapshot. I'm not sure if subclassing tf.data.Dataset is a common technique.
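For reference, those subclasses are constructed directly rather than by subclassing (the file paths below are hypothetical):

import tensorflow as tf

# TextLineDataset yields one string tensor per line of the input file(s).
lines = tf.data.TextLineDataset("/tmp/corpus.txt")

# CsvDataset (under tf.data.experimental) parses typed columns per row.
rows = tf.data.experimental.CsvDataset(
    "/tmp/table.csv", record_defaults=[tf.int32, tf.string]
)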

samhita-alla commented 2 years ago

In the case of serializing and deserializing the dataset, each epoch would see the same data.

Um, gotcha. I think that's expected because the compute is serialized, too. And I think that's okay as long as we let the user know about it and the compute is supported by serialization.

A tf.data.Dataset is very different from a TFRecord.

I agree. I was referring to handling tf.data.Dataset with a TFRecordFile Flyte type.

@task
def produce_record(...) -> TFRecordFile:
    return tf.data.Dataset(...)

github-actions[bot] commented 1 year ago

Hello 👋, This issue has been inactive for over 9 months. To help maintain a clean and focused backlog, we'll be marking this issue as stale and will close the issue if we detect no activity in the next 7 days. Thank you for your contribution and understanding! 🙏

github-actions[bot] commented 1 year ago

Hello 👋, This issue has been inactive for over 9 months and hasn't received any updates since it was marked as stale. We'll be closing this issue for now, but if you believe this issue is still relevant, please feel free to reopen it. Thank you for your contribution and understanding! 🙏

github-actions[bot] commented 3 months ago

Hello 👋, this issue has been inactive for over 9 months. To help maintain a clean and focused backlog, we'll be marking this issue as stale and will engage on it to decide if it is still applicable. Thank you for your contribution and understanding! 🙏

nightscape commented 2 months ago

In the documentation for Distributed TensorFlow Training, some tasks use tf.data.Dataset as input and/or output. Does that mean this now works? I see the class TensorFlowRecordFileTransformer(TypeTransformer[TFRecordFile]), but found nothing for tf.data.Dataset.

samhita-alla commented 2 months ago

@nightscape the type isn't available yet, so the data will be pickled.