miladm opened 2 years ago
The first question for me is whether `nonzero` assigns the expected SSA dynamic bits to the output tensor. The next question is whether `XLATensor::CreateFrom` propagates the IR node's dynamic properties as expected.
Assuming the dynamic bit for the first dimension of `nonzero` is not set by the upstream, here are my thoughts.

At the PT/XLA IR level, the `nonzero` output object goes through this call sequence before getting wrapped into an `at::Tensor`:

`XLATensor::CreateFrom(torch::lazy::Value ir_value...)` > `XLATensor::Create(torch::lazy::Value ir_value...)` > `Data(torch::lazy::Value ir_value...)`

where it gets stored as `torch::lazy::Value ir_value`.
Option 1:
For all dynamic ops, the upstream torch/LTC layer sets the `torch::lazy::shape.is_symbolic_` vector bits inside the `ir_value` node.
Option 2:
For all dynamic ops, the PyTorch/XLA IR layer sets the `torch::lazy::shape.is_symbolic_` vector bits inside the `ir_value` node. Below is the hypothetical code in `tensor_methods.cpp`:
```cpp
XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input) {
  torch::lazy::NodePtr node =
      torch::lazy::MakeNode<NonZero>(input->GetIrValue());
  torch::lazy::Value ir_value = torch::lazy::Value(node, 0);
  ir_value.shape().set_symbolic(0);  /* New torch::lazy::shape API to set dynamic bits */
  return input->CreateFrom(ir_value, at::ScalarType::Long);
}
```
My Takeaway: I think it is better design to set the dynamic properties of an op at the upstream level; having the downstream control the dynamic bits doesn't seem reasonable. @Krovatkin wdyt?
CC @JackCaoG @Gamrix
@miladm I believe it will be Option 2. We typically won't go through the LTC layer when tracing and generating IR nodes; XLA directly takes over, and we end up in `XLATensor::nonzero`, which will need to call `torch::lazy::` APIs (e.g. `with_symbolic_dims`). Since we expect most ops to be code-generated, these calls will be inserted for you automatically.
Thanks @Krovatkin!

I wrote some code to update the dynamic properties of the `nonzero` IR `Value` in PT/XLA. It turns out the problem I described earlier persists. Refer to this PR commit for details. We can discuss this in our upcoming offline chat.

In short, this line returns `false` even though this line returns `true`. I expect both to return `true`.
Running `xla_y[0].is_symbolic()` after `torch::Tensor xla_y = torch::nonzero(xla_x);` returns `false` even though we expect `true`. I assume the upstream dynamic shape analysis assigns the correct dynamic property to each dimension of the `nonzero` output tensor. @Krovatkin @Gamrix wdyt?

The full test code: