
[PRD] Bounded Dynamic Shape Design for `PyTorch` & `PyTorch/XLA` #3884

Open miladm opened 2 years ago

miladm commented 2 years ago

📚 Bounded Dynamic Shape Design

Contributors

List of contributors in chronological order:

Objective

PyTorch/XLA adopts the Lazy Tensor design, a runtime tracing mechanism that gives the user the illusion of eager execution while enabling real-time code optimization under the hood. A critical performance requirement of PyTorch/XLA is maintaining a low ratio of program trace compilations to reruns. To achieve this, PyTorch/XLA requires that program tensor shapes change rarely: the operations performed, the input shapes, the intermediate output shapes, and the output shapes must stay the same across training steps. Programs with dynamic shape properties conflict with this requirement and lead to significant recompilation and thereby to runtime regression. Computer vision detection / segmentation applications like Mask RCNN and SSD are popular deep learning algorithms that use dynamic shape. Bounded Dynamic Shape is a top customer-critical feature that has been missing in PyTorch and PyTorch/XLA.

An existing workaround for dynamic shape is for developers to manually pad dynamic tensors: they infer the static upper bounds of each tensor and propagate that property to consumer operations. This process is inconvenient and time consuming. Bounded Dynamic Shape automatically identifies static shape boundaries, propagates shapes, and dynamically pads tensors for the developer. As a result, ML developers focus their valuable time on model design and the compiler spends fewer cycles recompiling program traces at runtime.
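
For illustration, below is a minimal sketch of this manual workaround applied to torch.nonzero; the padding scheme and the choice of upper bound are assumptions made for the example, not part of this design:

import torch, torch_xla
import torch.nn.functional as F

x = torch.randint(low=0, high=2, size=(5, 5), device='xla:0')
nz = torch.nonzero(x)                # shape (n, 2), where n depends on the values in x
max_rows = x.numel()                 # static upper bound: at most 25 nonzero entries
nz_padded = F.pad(nz, (0, 0, 0, max_rows - nz.size(0)))   # pad rows so the shape is always (25, 2)
# Downstream ops now always see a (25, 2) tensor, so the trace shape never changes,
# but the developer must track and mask out the padded rows manually.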

TPU requires tensor shapes to be determined at compile time. Thus, PyTorch/XLA focuses on supporting bounded dynamic shapes. Supporting unbounded dynamic shape would give PyTorch/XLA even broader support for other accelerator backends like GPU; this feature is out of scope for this design phase. This document is a short summary of the FULL DESIGN.

Background

Bounded Dynamic Shape Definition

Dynamic Shape refers to the variable nature of a tensor shape, where the shape depends on the value of another upstream tensor. Specifically, this relationship is governed by an operation that produces an output tensor whose shape depends on the value of its input tensor(s). Below is a simple example showing how the shape of out_tensor is a function of the data stored in in_tensor, the input to torch.nonzero. As we will discuss later, the key property of out_tensor is that its upper bound dimensions can never exceed those of the in_tensor shape, i.e. (5, 5). An op like torch.nonzero is called a “dynamic op”, a tensor like in_tensor is called a “static tensor”, and a tensor like out_tensor is called a “dynamic tensor”.

import torch, torch_xla
in_tensor  = torch.randint(low=0, high=2, size=(5,5), device='xla:0')
out_tensor = torch.nonzero(in_tensor)   # shape (n, 2): n depends on the values in in_tensor

What is a Dynamic Zone?

Programs may pass through what we define as a “Dynamic Zone”. A dynamic zone starts with an operation that has a {static input tensor, dynamic output tensor} signature and terminates with an operation that has a {dynamic input tensor, static output tensor} signature. Below is a code example that illustrates this property. In this example, t2 and t3 are dynamic tensors, while t1, t4, and t5 are static.

import torch, torch_xla
t1 = torch.randint(low=0, high=2, size=(5,5), device='xla:0')
t2 = torch.nonzero(t1)     # <-- Dynamic Zone Start
t3 = torch.pow(t2, 2)
t4 = t3.sum()              # <-- Dynamic Zone End
t5 = t4 + 2

The following table summarizes the conditions that start and end a dynamic zone.

Dynamic Zone Start:
    • Dynamic ops [in scope]: masked_select, unique, nonzero
    • Dynamic input tensors [in scope for phase 2]

Dynamic Zone End:
    • Graph termination (e.g. print, mark_step)
    • Reduction ops (e.g. sum)
    • Conditional ops (e.g. if)
    • Ops with non-tensor return type (e.g. unbind)
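
For example, a zone opened by masked_select (whose output length depends on the values of its input) can end at a graph termination point such as mark_step rather than at a reduction. Below is a minimal illustrative sketch along the lines of the earlier examples:

import torch, torch_xla
import torch_xla.core.xla_model as xm

t1 = torch.randn(5, 5, device='xla:0')
t2 = torch.masked_select(t1, t1 > 0)   # <-- Dynamic Zone Start
t3 = t2 * 2                            # t3 is still dynamic
xm.mark_step()                         # <-- Dynamic Zone End (graph termination)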

SOTA for Bounded Dynamic Shape

Bounded Dynamic Shape was introduced in TensorFlow/XLA. In this design, XLA introduces two key operation semantics to set and get the dynamic dimensions of a dynamic tensor at compile time: xla::SetDimensionSize and xla::GetDimensionSize. The former sets the dynamic dimensions of a leaf IR node and provides support for propagating dynamic properties to subsequent nodes. The latter materializes the dynamic dimensions of an op. Dynamic shape in XLA is built upon several components: Dynamic Shape representation in HLO, Dynamic Padder, the Get/SetDimensionSize API, Shape Inference, Value Inference, TF2XLA Bridge Kernels, and Automatic Bucketing. In this initiative, PyTorch adopts Dynamic Shape via the PyTorch Lazy Tensor Core. As a result, this document discusses the dynamic shape design for PyTorch and PyTorch/XLA. PyTorch/XLA will leverage the underlying XLA APIs to support dynamism.

PyTorch Building Blocks to Enable Dynamic Shape

Several underlying building blocks are required to realize Bounded Dynamic Shape in PyTorch. Inspired by the original PyTorch/XLA frontend design, Lazy Tensor Core (LTC) is the enabling component that delivers Lazy Tensor execution to virtually any backend accelerator. As shown in Figure 1, LTC is designed to enable flexible backend integration with PyTorch/XLA and TorchScript; these backend layers support dedicated accelerator kernels.

Figure 1: The PyTorch ecosystem necessary to enable Bounded Dynamic Shape for Lazy Tensor.

Overview

To support Bounded Dynamic Shape in PyTorch/XLA, we require several technologies across the stack. Figure 2 illustrates the technologies we require at the PyTorch, PyTorch/XLA, and XLA layers. The only features readily available to use are those at the XLA layer; the remainder will be implemented in this project. The detailed design behind each block is discussed in the FULL DESIGN document.

Figure 2: The blocks required to support Bounded Dynamic Shape in PyTorch/LTC and PyTorch/XLA.

Bounded Dynamic Shape Design API

Dynamic Shape demands support deep in the code stack, from PyTorch to PyTorch/XLA to TensorFlow/XLA. TensorFlow/XLA already supports bounded dynamic shapes. In this work, we focus our attention on the two former layers (i.e. PyTorch and PyTorch/XLA).

A bounded dynamic shape design that abstracts code analysis away from the user requires several key components. First, it must identify tensors with dynamic dimensions; second, it must use shape inference to determine the upper bound dimensions of dynamic tensors throughout the IR / HLO graph; third, it requires value inference to determine upper/lower bounds for dynamic tensors with data-dependent shape sizes (e.g. torch.arange()); finally, it requires a dynamic padder to appropriately pad tensors prior to execution. The PyTorch eager code tracing model enables propagating the bulk of the dynamic shape context.
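
As a purely conceptual illustration of the difference between shape inference and value inference, the toy helpers below propagate upper bounds for nonzero and arange; the helper names are hypothetical and are not part of any PyTorch or XLA API:

# Hypothetical, illustrative only: shape inference bounds the *dimensions* of a
# dynamic tensor, while value inference bounds the *values* that later become
# shape sizes (e.g. the argument of arange).

def nonzero_upper_bound(input_shape):
    # nonzero on a tensor of shape (d0, ..., dk) returns at most d0*...*dk rows
    rows = 1
    for d in input_shape:
        rows *= d
    return (rows, len(input_shape))

def arange_upper_bound(n_upper_bound):
    # arange(n) has at most n_upper_bound elements when n <= n_upper_bound
    return (n_upper_bound,)

nz_bound = nonzero_upper_bound((5, 5))      # (25, 2): shape inference
ar_bound = arange_upper_bound(nz_bound[0])  # (25,):  value inference bounds n by 25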

To address the first and second steps above, we need to infer two shapes for each dynamic tensor: (1) static upper bounds, (2) the actual tensor bounds. The first requirement is handled by the upstream PyTorch eager shape inference step during tracing. The second requirement is handled by the downstream xla::GetDimensionSize API.

To address the first and third steps above, we require propagating dynamic shape across dynamic tensors with data-dependent shape sizes. In other words, we require dynamic shape propagation across the size() operation. The Size API Abstraction section discusses the design.

Size API Abstraction

As much as possible, we intend to abstract away dynamic shape details from the user. Our design solution is called Boxed Dynamic Shape, a reimplementation of the size() API at the PyTorch C++ and Python levels; visit Boxed Dynamic Shapes for the in-depth design.

Boxed Dynamic Shape

A simple and effective method is to describe the size() expression as an IR node at the C++ and Python levels by mapping this operation to a SymbolicInt (SymInt) abstraction. Conceptually, SymInt is a union of a concrete integral dimension value, a size expression (i.e. an IR node like aten::add(aten::size(t, 0)*aten::size(t2, 1))), or an index into a side table where size() computation expressions are stored. Below we discuss the key API considerations for this design. For additional details, visit Boxed Dynamic Shapes and review the complete POC PR-new (and PR-old) of Boxed Dynamic Shape.

A basic requirement of the PyTorch C++ API is that it must continue to return vectors of int64_t to represent tensor sizes. This simple yet difficult constraint requires the PyTorch tracing system to reuse the same C++ code while reckoning with the existing static shape format. Rather than rewriting the C++ API more extensively, we propose to smuggle dynamic size information through an int64_t in a “boxed representation” which can be unboxed outside of the restricted API and allows tracers to capture size arithmetic as part of the program IR rather than as specific constants. More specifically, we propose to build a dynamic shapes tracing solution that captures size arithmetic as IR rather than as baked-in constants. PyTorch/XLA adopts and integrates the boxed dynamic shapes API.
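
As a purely conceptual sketch of the boxed representation (the tag bit, side table, and helper names below are hypothetical illustrations, not the actual SymInt implementation):

# Hypothetical illustration only: encode either a concrete size or an index into a
# side table of size expressions inside a single int64, so existing C++ signatures
# that return std::vector<int64_t> do not have to change.
TAG_BIT = 1 << 62                     # high bit marks "this int64 is really a handle"
size_expr_table = []                  # side table of captured size expressions (IR nodes)

def box_size_expr(expr):
    size_expr_table.append(expr)
    return TAG_BIT | (len(size_expr_table) - 1)

def unbox(boxed):
    if boxed & TAG_BIT:               # symbolic: look up the size expression
        return size_expr_table[boxed & ~TAG_BIT]
    return boxed                      # concrete: the int64 is the dimension itself

d0 = 5                                                 # static dim: plain integer
d1 = box_size_expr("aten::size(nonzero(t1), 0)")       # dynamic dim: boxed handle
print(unbox(d0), unbox(d1))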

PyTorch/XLA IR Level Shape Handling via MetaTensors

To integrate the MetaTensor / Structured Kernel implementation, PyTorch/XLA will adopt the Lazy Tensor codegen to retrieve the PyTorch eager shape inference results. In the short term, xla::Shape and lazy::Shape will coexist. In the long term, PyTorch/XLA will remove the IR xla::Shape calls and replace them with lazy::Shape calls passed from PyTorch.

User Experience

Our goal is for Dynamic Shape to have close to no impact on the user; in practice, this technology has only a minimal impact on user experience. Outside a dynamic zone, a size() call returns a concrete int. Inside a dynamic zone, a size() call returns a SymInt. The following three scenarios explain how a size() call behaves throughout a program; a short illustrative sketch follows the list.

  1. When a dynamic tensor is passed to a downstream op (e.g. t2 passed to torch.pow in the example above), under the hood the op silently receives a symbolic int object.
  2. Upon printing the size() of a dynamic tensor, its “upper bounds” are returned as a string.
  3. When the true shape of a dynamic tensor is required (e.g. in a conditional statement), the program executes to materialize the actual tensor shape for follow-up computation.
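
Below is a minimal sketch of these three scenarios, assuming dynamic shape support for the ops involved; the commented behavior follows the description above rather than any particular implementation:

import torch, torch_xla

t1 = torch.randint(low=0, high=2, size=(5,5), device='xla:0')
t2 = torch.nonzero(t1)            # dynamic tensor

# 1. Passing t2 downstream: torch.pow silently receives symbolic sizes for t2.
t3 = torch.pow(t2, 2)

# 2. Printing the size of a dynamic tensor reports its upper bounds.
print(t2.size())

# 3. Using the true shape (e.g. in a conditional) forces execution so the
#    actual shape is materialized before the comparison runs.
if t2.size(0) > 10:
    t3 = t3 + 1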

Failure Handling Strategy

Supporting dynamic shape will be an incremental journey that involves updating a significant number of ops in PyTorch and PyTorch/XLA. Throughout this journey, we want to enable the user to understand the extent of support; when a set of ops is missing dynamic shape support, the user will be notified via meaningful error messages indicating that the expected behavior is “Not Yet Supported”. PyTorch will integrate a default error message in the LTC codegen to inform the user when an op lacks dynamic shape support (github issue).

It is possible for the support levels of PyTorch and PyTorch/XLA to diverge. This happens when an op has dynamism support in PyTorch but not yet in PyTorch/XLA. PyTorch/XLA will maintain a blacklist of ops that don’t have dynamic shape support, and will notify the user via a meaningful error message (github issue). The same strategy applies to the scenarios discussed in the Limitations section.

Testing Plan

Our implementation process progresses from supporting simple ops, to supporting simple models, to supporting Mask R-CNN. The support for Mask R-CNN will come in phases; one way to stage the process is via supporting its submodules such as Faster R-CNN, RoIAlign, etc. At each step of the way, we test this design on TPU and GPU devices.

Dynamism Beyond PyTorch/LTC & PyTorch/XLA:TPU

Do We Need Unbounded Dynamic Shape?

Bounded dynamic shape has real-world uses in numerous deep learning applications. The authors are unaware of any real-world deep learning applications that would produce unbounded dynamic shape. Having said that, freeing the model developer from setting tensor bounds is an advantage that trades developer time for compilation time. Whether unbounded dynamism truly offers an advantage to the deep learning community (both modeling and systems) remains an active topic for debate and analysis.

torch.MLIR Adoption of Bounded Dynamic Shape

Today, torch.MLIR, a Google research project, is adopted by a few accelerator vendors for several reasons, one of which is backend dialect flexibility; it supports LINALG, TOSA, and MHLO. Today, torch.MLIR supports unbounded dynamic shape. The project considers bounded dynamism out of scope and quite challenging to integrate. The challenge is that torch.MLIR can propagate either static or unbounded (i.e. ?) shape dimensions, but NOT upper bounds. However, torch.MLIR supports PyTorch/LTC via its LTC MLIR plugin. Once bounded dynamism is realized in PyTorch/LTC, the torch.MLIR community will be able to support bounded dynamism.

JAX Dynamic Shape Support

JAX plans to support bounded and unbounded dynamism in one of its upcoming releases. This support is inspired by the implementation of Dynamic Shape in Dex. A key difference between JAX and PyTorch/XLA adoption of dynamic shape is that JAX must handle dynamism within loops, something PyTorch/XLA is not concerned with due to its use of JIT tracing across loops.

TorchDynamo Dynamic Shape

TorchDynamo has a Dynamic Shape project that offers a complementary agenda to this design. This project aims to build guards that enable symbolic shape inference at conditional statements, thereby enabling trace generation across conditional blocks. It creates larger graph traces by symbolically inferring tensor shapes and by determining whether the shapes meet the conditional guards put in place for the given trace. Similar to the if-else Statements section, TorchDynamo may materialize dynamic tensors at conditionals when the dynamic ops are, for example, torch.nonzero or torch.unique. For more, visit the Future Work section.

Limitations

Ragged Tensors

Since ragged tensors are yet to be supported in XLA, PyTorch/XLA is blocked from offering support for this feature in this phase of development.

Dynamic Shape Across Control Ops

Conditional statements are the termination points for dynamic zones. We will explore extending dynamic shape beyond control statements in the future when customer demand is present.

Future Work

In future phases of this project, we would like to bring the following features into scope, according to demand from the OSS community and customers. The top three items are under consideration for integration in the second phase of this project.

  1. Support python APIs that enable the user to set upper bounds on input tensors
    • This API extends the Dynamic Zone API to input tensors. Otherwise, we expect the technology to be closely aligned with this design.
  2. Support unbounded dynamic shape (for the GPU use case)
  3. Address the current design Limitations as discussed earlier
  4. Enable dynamic shape across Conditional Ops as discussed earlier
  5. Build necessary integrations to enable Dynamic Shape with gSPMD
  6. Build necessary integrations to enable Dynamic Shape with TorchDynamo
adoda commented 1 year ago

👍, Is there an update on Bounded Dynamic Shape? When can we run an example?

JackCaoG commented 1 year ago

@adoda You can follow up using the dynamism tag under https://github.com/pytorch/xla/issues?q=is%3Aissue+is%3Aopen+label%3Adynamism. We have spent most of our time making the C++ infra code ready; now we have started to surface that infra through Python APIs.

miladm commented 1 year ago

PyTorch/XLA will soon enable StableHLO support. This RFC from StableHLO describes how dynamism will be supported by the downstream layer.

CC @ezyang @wconstab @smit-hinsu @burmako @shauheen @JackCaoG

burmako commented 1 year ago

Milad, thank you for the introduction! I also wanted to provide some additional context on StableHLO's plans for dynamism. In a nutshell these plans involve: 1) Support for bounded dynamism (already informally exists in StableHLO, needs to be formalized). 2) Support for unbounded dynamism (already informally exists in StableHLO, needs to be formalized). 3) Potential unification.

1) For bounded dynamism, the RFC linked by Milad talks about the design in StableHLO which involves roughly the same functionality as HLO: a) bounds for dimension sizes in tensor types, b) get_dimension_size and set_dimension_size ops that map directly on HLO ops. In StableHLO, we take compatibility very seriously, and we want to keep providing these APIs for frontends like PyTorch/XLA that have investment in HLO-style dynamism.

2) For unbounded dynamism, we're going to have a separate RFC in the near future, and the work on this RFC is tracked in https://github.com/openxla/stablehlo/issues/8. In the unbounded dynamism RFC, we'll be aiming to:

3) In the long term, we would also like to unify bounded and unbounded APIs, and the RFC linked by Milad talks about some initial ideas in this area (see the "(P4) Aspirational: Migration to unbounded dynamism" section). This is some very early thinking, and I'm not yet sure how this will work out - it would be preferable to have just one API, but maybe we'll end up with two.