Closed rainwoodman closed 1 year ago
Hi! The review meeting for this RFC was conducted a while ago, and the draft was approved by reviewers in February. Shall we merge this PR so that I can link to a more proper-looking artifact when referring to it in discussions?
Please take a look and let me know if there is any additional process. Thanks!
Thanks for the merge!
Objective
The new TensorFlow Distribution API extends the TensorFlow Core API (e.g., tf.Variable, tf.function, and tf.Module) with distributed tensor computation.
The low-level component of the Distribution API, built with DTensor, provides uniform SPMD semantics across CPU, GPU, and TPU device types. DTensor is an intrinsic part of TensorFlow that defines a representation of distributed tensors with Mesh and Layout data structures. Users and high-level libraries (such as Keras) can depend on Mesh and Layout just as on other components of the TensorFlow low-level API. An initial experimental implementation is covered on TensorFlow.org:

- DTensor Concepts
- DTensor ML Tutorial
- DTensor Keras Tutorial
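As a rough sketch of how Mesh and Layout appear in the experimental implementation (the `tf.experimental.dtensor` module; names and details may differ in the final design), the snippet below builds a mesh out of logical CPU devices and places a tensor on it:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into 8 logical devices, so a mesh can be
# built and exercised without real accelerators.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 8)

# A Mesh names a device topology: here, 8 CPUs arranged as a 4x2 grid
# with dimensions "batch" and "model".
mesh = dtensor.create_mesh(
    [("batch", 4), ("model", 2)],
    devices=[f"CPU:{i}" for i in range(8)])

# A Layout maps tensor axes onto mesh dimensions: the first tensor axis
# is sharded over the "batch" mesh dimension, the second is replicated.
layout = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)

# Materialize a distributed tensor with that layout.
t = dtensor.call_with_layout(tf.zeros, layout, shape=(16, 4))
```

The global shape of `t` is still (16, 4); each of the 8 devices holds only a (4, 4) shard, replicated along the "model" dimension.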
This RFC defines the integration between TensorFlow and DTensor, the low-level component of TensorFlow's next-generation Distribution API.
DTensor defines a uniform, generic API for composing distributed TensorFlow programs across the accelerator types supported by TensorFlow. Common distribution patterns in machine learning, including data and model parallelism, spatial partitioning, and pipelining, can all be expressed with the primitives offered in this RFC.
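To illustrate how a distribution pattern reduces to these primitives, the hedged sketch below (using the experimental `tf.experimental.dtensor` module; this is not the final API surface) expresses data parallelism and model parallelism as nothing more than two different Layouts on the same mesh:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Emulate 4 devices with logical CPUs.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 4)

mesh = dtensor.create_mesh(
    [("batch", 2), ("model", 2)],
    devices=[f"CPU:{i}" for i in range(4)])

# Data parallelism: shard the leading (batch) axis, replicate the rest.
data_parallel = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)
# Model parallelism: shard a weight axis over the "model" mesh dimension.
model_parallel = dtensor.Layout([dtensor.UNSHARDED, "model"], mesh)

x = dtensor.call_with_layout(tf.ones, data_parallel, shape=(8, 4))
w = dtensor.call_with_layout(tf.ones, model_parallel, shape=(4, 6))

# An ordinary op on DTensor inputs executes as an SPMD program
# across the mesh; no per-pattern code path is needed.
y = tf.matmul(x, w)
```

The same `tf.matmul` call serves both patterns; only the Layouts differ, which is the sense in which the API is uniform.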
A very basic form of interoperability with other ML frameworks, such as JAX, is also supported by the API described in this RFC.
This document also demonstrates a potential path for integration with the Keras modeling primitives in the form of DTensorStrategy, a new subclass of tf.distribute.Strategy.