JackCaoG opened this issue 2 years ago
One thing I noticed just now is that there are some PyTorch/XLA-specific ops that have no meta tensor or shape computation support yet. One example would be the cc ops like `all_reduce` and `all_gather`. These ops have IR nodes but no corresponding PyTorch ops yet. I guess we need to implement `torch::lazy::shape` functions for these ops.
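A minimal sketch of what such a shape function could look like for a single-tensor `all_reduce`, assuming the output simply mirrors the input's dtype and sizes. The name follows the upstream `compute_shape_*` convention, but this particular function does not exist yet and is purely illustrative (the real cc ops also carry token operands and can take tensor lists):

```cpp
#include <ATen/core/Tensor.h>
#include <torch/csrc/lazy/core/shape.h>

#include <vector>

// Hypothetical shape function for the cc op all_reduce: the reduction happens
// across replicas, so the per-replica output has exactly the input's dtype
// and sizes. (The token operand of the real op is omitted here.)
std::vector<torch::lazy::Shape> compute_shape_all_reduce(const at::Tensor& input) {
  return {torch::lazy::Shape(input.scalar_type(), input.sizes())};
}
```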
Another quick takeaway is that adding the shape support involves quite a bit of redundant copy-and-paste code. I think I will change my plan and only manually insert shape functions for those ops that cannot be easily generated. One example would be `_adaptive_avg_pool2d_backward`, where we have to conditionally fall back.
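As a rough illustration, a manually inserted shape function for `_adaptive_avg_pool2d_backward` could boil down to something like the sketch below, assuming the gradient always matches the original input's dtype and sizes; the conditional-fallback handling mentioned above is not shown here:

```cpp
#include <ATen/core/Tensor.h>
#include <torch/csrc/lazy/core/shape.h>

#include <vector>

// Sketch of a hand-written shape function: the gradient w.r.t. the input of
// adaptive_avg_pool2d has the same dtype and sizes as the original input
// (`self`), independent of the pooled sizes carried by grad_output.
std::vector<torch::lazy::Shape> compute_shape__adaptive_avg_pool2d_backward(
    const at::Tensor& grad_output, const at::Tensor& self) {
  return {torch::lazy::Shape(self.scalar_type(), self.sizes())};
}
```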
🚀 Feature
Currently PyTorch/XLA uses `xla::shape` all over the place. Common uses of `xla::shape` are to get the number of elements of a tensor, compare shapes for equality between two tensors, check whether a reshape is valid, check whether a cast is needed, and so on. Going forward we want to use `torch::lazy::shape` at the `XLATensor` and IR level, and only use `xla::shape` at the `XlaOp` and runtime level.
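To make the goal concrete, the common checks listed above could be expressed purely against `torch::lazy::Shape` (dtype + sizes, no layout) along the following lines. The helper names are made up for illustration and are not existing torch_xla utilities:

```cpp
#include <c10/util/ArrayRef.h>
#include <torch/csrc/lazy/core/shape.h>

#include <cstdint>

// Number of elements implied by a torch::lazy::Shape (illustrative helper).
int64_t NumElements(const torch::lazy::Shape& shape) {
  int64_t numel = 1;
  for (int64_t size : shape.sizes()) {
    numel *= size;
  }
  return numel;
}

// Shape equality at the XLATensor/IR level: dtype + sizes only. Because
// torch::lazy::Shape carries no layout, this comparison is layout-agnostic
// by construction.
bool ShapesEqual(const torch::lazy::Shape& a, const torch::lazy::Shape& b) {
  return a.scalar_type() == b.scalar_type() && a.sizes().equals(b.sizes());
}

// A reshape is valid when the element counts match.
bool IsValidReshape(const torch::lazy::Shape& from,
                    c10::ArrayRef<int64_t> target_sizes) {
  int64_t target_numel = 1;
  for (int64_t size : target_sizes) {
    target_numel *= size;
  }
  return NumElements(from) == target_numel;
}

// A cast is needed when the element types differ.
bool NeedsCast(const torch::lazy::Shape& from, const torch::lazy::Shape& to) {
  return from.scalar_type() != to.scalar_type();
}
```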
Motivation

This is part of the LTC migration that is already happening. Upstream LTC does not know about `xla::shape` and only uses `torch::lazy::shape`. The biggest differences between `torch::lazy::shape` and `xla::shape` are:

1. `torch::lazy::shape` does not have a layout.
2. `torch::lazy::shape` does not have dynamic information (which dimension is dynamic) yet.

For 1 I think we can delay the layout assignment until the HLO lowering and data uploading. For 2 I think @Krovatkin and the Meta team have a plan for this.
If we can achieve this goal, it will make the `XLATensor` -> `LazyTensor` inheritance much easier and will unblock the dynamic shape work.

Steps
1. Compute the output `torch::lazy::shape` given the input `torch::lazy::shape`. (We also need to make sure a shape cache is in place so we don't keep calling the shape functions.)
2. Pass `std::vector<torch::lazy::Shape>&& shapes` to the IR node constructors.
3. Handle the conversion between `xla::shape` and `torch::lazy::shape`; make sure `XLA_USE_BF16`, which keeps the `at::Tensor` as f32 but the underlying `xla::shape` as bf16, still works.
4. Replace `xla::shape` with `torch::lazy::shape`.

For 1 we can start from the existing `xla::Shape` inference functions like https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp#L111, and 2 is already happening through the Codegen work, but we don't need to be blocked by that: we can manually add shape functions to achieve the same thing. A sketch of the step-3 conversion follows.
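As a sketch of the conversion mentioned in step 3, one direction (`xla::Shape` -> `torch::lazy::Shape`) could look roughly like this, assuming we simply drop the layout and map the element type. The helper names and the truncated dtype table are illustrative, not existing torch_xla APIs, and tuple shapes (e.g. from the cc ops) would need extra handling:

```cpp
#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/shape.h>

#include <stdexcept>
#include <vector>

#include "tensorflow/compiler/xla/shape.h"  // include path depends on the pinned TF/XLA version

// Illustrative (and deliberately incomplete) PrimitiveType -> ScalarType table.
at::ScalarType ScalarTypeFromPrimitiveType(xla::PrimitiveType type) {
  switch (type) {
    case xla::PrimitiveType::F32:  return at::ScalarType::Float;
    case xla::PrimitiveType::F64:  return at::ScalarType::Double;
    case xla::PrimitiveType::BF16: return at::ScalarType::BFloat16;
    case xla::PrimitiveType::S32:  return at::ScalarType::Int;
    case xla::PrimitiveType::S64:  return at::ScalarType::Long;
    case xla::PrimitiveType::PRED: return at::ScalarType::Bool;
    default:
      throw std::runtime_error("unhandled xla::PrimitiveType in sketch");
  }
}

// Drop the layout, keep element type + dimensions (array shapes only here).
// Note that with XLA_USE_BF16 the device-side xla::Shape is bf16 while the
// at::Tensor (and hence the torch::lazy::Shape the upstream lazy core sees)
// stays f32, so a real conversion cannot be a pure element-type mapping.
torch::lazy::Shape LazyShapeFromXlaShape(const xla::Shape& xla_shape) {
  std::vector<int64_t> sizes(xla_shape.dimensions().begin(),
                             xla_shape.dimensions().end());
  return torch::lazy::Shape(
      ScalarTypeFromPrimitiveType(xla_shape.element_type()), sizes);
}
```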
One thing to note is that we will still get the output `xla::shape` when we do the op lowering. I don't know if we will want to set that `xla::shape` back on the IR node (I doubt that will be very helpful).

@Krovatkin @miladm @wonjoolee95 @wconstab