pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[LTC] Use torch::lazy::shape in the XLATensor #3725

Open JackCaoG opened 2 years ago

JackCaoG commented 2 years ago

🚀 Feature

Currently PyTorch/XLA uses xla::Shape all over the place. Common uses of xla::Shape include getting the number of elements of a tensor, comparing shape equality between two tensors, checking whether a reshape is valid, and checking whether a cast is needed. Going forward we want to use torch::lazy::Shape at the XLATensor and IR level, and only use xla::Shape at the XlaOp and runtime level.
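
As a rough sketch of what those common queries would look like on torch::lazy::Shape (the accessors follow torch/csrc/lazy/core/shape.h; the NumElements/SameShape helpers are illustrative, not existing APIs):

```cpp
#include <algorithm>
#include <functional>
#include <numeric>

#include "torch/csrc/lazy/core/shape.h"

// Number of elements: torch::lazy::Shape has no ElementsIn()-style helper,
// so fold over sizes() (compare xla::ShapeUtil::ElementsIn(xla_shape)).
int64_t NumElements(const torch::lazy::Shape& shape) {
  return std::accumulate(shape.sizes().begin(), shape.sizes().end(),
                         int64_t{1}, std::multiplies<int64_t>());
}

// Shape equality: dtype + sizes (compare xla::ShapeUtil::Equal(a, b)).
bool SameShape(const torch::lazy::Shape& a, const torch::lazy::Shape& b) {
  return a.scalar_type() == b.scalar_type() &&
         std::equal(a.sizes().begin(), a.sizes().end(),
                    b.sizes().begin(), b.sizes().end());
}
```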

Motivation

This is part of the LTC migration that is already happening. Upstream LTC does not know about xla::Shape and only uses torch::lazy::Shape. The biggest differences between torch::lazy::Shape and xla::Shape are:

  1. torch::lazy::Shape does not have a layout.
  2. torch::lazy::Shape does not have dynamic information (which dimensions are dynamic) yet.

For 1, I think we can delay the layout assignment until HLO lowering and data uploading. For 2, I think @Krovatkin and the Meta team have a plan for this.
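
To make the first point concrete, here is a minimal sketch of what delaying the layout assignment could look like; MakeXlaShape is a hypothetical helper (not an existing pytorch/xla function) and the choice of a plain descending layout is just for illustration:

```cpp
// Hypothetical helper: convert the layout-less torch::lazy::Shape to an
// xla::Shape only at HLO lowering / data upload time, and pick the layout
// there instead of storing it on the XLATensor/IR.
xla::Shape MakeXlaShape(const torch::lazy::Shape& shape,
                        xla::PrimitiveType element_type) {
  // torch::lazy::Shape only carries dtype + static sizes.
  xla::Shape xla_shape = xla::ShapeUtil::MakeShape(
      element_type,
      std::vector<int64_t>(shape.sizes().begin(), shape.sizes().end()));
  // Layout assignment happens here, late in the pipeline.
  *xla_shape.mutable_layout() =
      xla::LayoutUtil::MakeDescendingLayout(shape.sizes().size());
  return xla_shape;
}
```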

If we can achieve this goal, it will make the XLATensor -> LazyTensor inheritance much easier and will unblock the dynamic shape work.

Steps

  1. Manually call a shape function (or use the meta tensor function) that computes the output torch::lazy::Shape given the input torch::lazy::Shape. (We also need to make sure a shape cache is in place so we don't keep calling shape functions; see the cache sketch further below.)
  2. Manually update all IR class constructors to take std::vector<torch::lazy::Shape>&& shapes (see the sketch after this list).
  3. Move the layout assignment to a separate code path and verify the 1-to-1 mapping between xla::Shape and torch::lazy::Shape.
  4. Make sure utilities like XLA_USE_BF16, which keep the at::Tensor as f32 but the underlying xla::Shape as bf16, still work.
  5. Replace the XLATensor-level xla::Shape with torch::lazy::Shape.
  6. Delete all xla::Shape inference functions like https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp#L11
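
A rough sketch of what steps 1 and 2 could look like; CastOutputShape and the XlaNode base-class signature shown here are illustrative, not the actual pytorch/xla code:

```cpp
// Step 1 (illustrative): a hand-written shape function that computes the
// output torch::lazy::Shape from the input torch::lazy::Shape.
std::vector<torch::lazy::Shape> CastOutputShape(
    const torch::lazy::Shape& input_shape, at::ScalarType dtype) {
  // A cast keeps the sizes and only changes the scalar type.
  return {torch::lazy::Shape(dtype, input_shape.sizes())};
}

// Step 2 (illustrative): the IR node constructor takes the precomputed
// shapes by rvalue reference instead of deriving an xla::Shape itself.
class Cast : public XlaNode {  // base-class call shown schematically
 public:
  Cast(const torch::lazy::Value& input, at::ScalarType dtype,
       std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::to), {input},
                std::move(shapes), /*num_outputs=*/1),
        dtype_(dtype) {}

 private:
  at::ScalarType dtype_;
};
```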

1 and 2 are already happening through the codegen work, but we don't need to be blocked by that; we can manually add shape functions to achieve the same thing.
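
The shape cache mentioned in step 1 could be as simple as a memoization table keyed by a hash of the op and its operands; the sketch below is illustrative and does not reflect the actual pytorch/xla cache implementation:

```cpp
#include <optional>
#include <unordered_map>
#include <vector>

#include "torch/csrc/lazy/core/shape.h"

// Illustrative shape cache: memoize shape-function results so repeated
// tracing of the same op does not keep recomputing shapes.
class ShapeCache {
 public:
  std::optional<std::vector<torch::lazy::Shape>> Get(size_t key) const {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return std::nullopt;
    }
    return it->second;
  }

  void Set(size_t key, std::vector<torch::lazy::Shape> shapes) {
    cache_[key] = std::move(shapes);
  }

 private:
  std::unordered_map<size_t, std::vector<torch::lazy::Shape>> cache_;
};
```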

One thing to note is that we will still get the output xla::Shape when we do the op lowering. I don't know whether we will want to set the xla::Shape back on the IR (I doubt that would be very helpful).
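
For reference, a minimal sketch of how the output xla::Shape remains observable at lowering time (the function and variable names are illustrative; GetShape is the xla::XlaBuilder query):

```cpp
// Illustrative: during lowering the builder already knows the output
// xla::Shape, so it can be queried on the spot without storing it back
// on the IR node.
xla::XlaOp LowerReshape(xla::XlaOp input, absl::Span<const int64_t> sizes) {
  xla::XlaOp output = xla::Reshape(input, sizes);
  xla::Shape output_shape =
      input.builder()->GetShape(output).value();  // or ValueOrDie()
  (void)output_shape;  // in this sketch we don't propagate it back to the IR
  return output;
}
```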

@Krovatkin @miladm @wonjoolee95 @wconstab

JackCaoG commented 2 years ago

One thing I noticed just now is that there are some PyTorch/XLA-specific ops for which there is no meta tensor or shape computation support yet.

One example would be the cc ops like all_reduce and all_gather. These ops have IR nodes but no corresponding PyTorch ops yet. I guess we need to implement torch::lazy::Shape functions for these ops.
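
For the simple cases this could be a small hand-written function; the sketch below is hypothetical (AllReduceOutputShape is not an existing function) and assumes the all_reduce outputs keep the input shapes:

```cpp
#include <vector>

#include "torch/csrc/lazy/core/shape.h"

// Hypothetical hand-written shape function for the all_reduce cc op, which
// has no upstream meta/shape support: the reduction is elementwise across
// replicas, so each output shape matches the corresponding input shape.
std::vector<torch::lazy::Shape> AllReduceOutputShape(
    const std::vector<torch::lazy::Shape>& input_shapes) {
  return input_shapes;
}
```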

JackCaoG commented 2 years ago

Another quick takeaway is that adding the shape support involves quite a bit of redundant copy-pasted code. I think I will change my plan to only manually insert shape functions for those ops that cannot be easily generated. One example would be _adaptive_avg_pool2d_backward, where we have to conditionally fall back.
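
As an illustration of such a hand-written case (the function name and signature are hypothetical): whichever lowering path _adaptive_avg_pool2d_backward takes, the gradient with respect to the input has the same shape as the original input, so the shape function only needs to be written once by hand.

```cpp
#include <vector>

#include "torch/csrc/lazy/core/shape.h"

// Hypothetical hand-written shape function for _adaptive_avg_pool2d_backward:
// whether we lower natively or conditionally fall back, the gradient w.r.t.
// self has the same shape as self.
std::vector<torch::lazy::Shape> AdaptiveAvgPool2dBackwardOutputShape(
    const torch::lazy::Shape& grad_output, const torch::lazy::Shape& self) {
  (void)grad_output;  // only the original input's shape matters here
  return {self};
}
```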