pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

`nonzero` does not return the expected `symint` dimensions #3706

Open miladm opened 2 years ago

miladm commented 2 years ago

Running `xla_y.sym_sizes()[0].is_symbolic()` after `torch::Tensor xla_y = torch::nonzero(xla_x);` returns `false`, even though we expect it to return `true`.

I assume the upstream dynamic shape analysis assigns the correct dynamic property to each dimension of the `nonzero` output tensor. @Krovatkin @Gamrix wdyt?

The full test code:

TEST_F(AtenXlaTensorTest, TestExpandSymInt) {
  // Eager reference: nonzero's output size along dim 0 is data-dependent.
  torch::Tensor x = torch::rand({5});
  torch::Tensor y = torch::nonzero(x);
  int64_t y0_size = y.sizes()[0];
  torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = a.expand({y0_size, 3, 4}, /*implicit=*/false);

  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_x = CopyToDevice(x, device);
    torch::Tensor xla_y = torch::nonzero(xla_x);
    // Expected to be a symbolic SymInt, since dim 0 of nonzero is dynamic.
    c10::SymInt xla_y0_size = xla_y.sym_sizes()[0];
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = xla_a.expand_symint(
        c10::SymIntArrayRef({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}),
        /*implicit=*/false);
    AllClose(b, xla_b);
    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
    ExpectCounterChanged("xla::expand_symint", cpp_test::GetIgnoredCounters());
  });
}
miladm commented 2 years ago

The first question for me is whether `nonzero` assigns the expected SSA dynamic bits to the output tensor.

miladm commented 2 years ago

The next question for me is whether `XLATensor::CreateFrom` propagates the IR node's dynamic properties as expected.

miladm commented 2 years ago

Assuming the dynamic bit for the first dimension of `nonzero` is not set by upstream, here are my thoughts:

At the PT/XLA IR level, the `nonzero` output goes through this call sequence before being wrapped into an `at::Tensor`:

`XLATensor::CreateFrom(torch::lazy::Value ir_value...)` > `XLATensor::Create(torch::lazy::Value ir_value...)` > `Data(torch::lazy::Value ir_value...)`, where it is ultimately stored as `torch::lazy::Value ir_value`.
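
For reference, here is a condensed sketch of that chain (paraphrased from the pt/xla sources with simplified signatures, not verbatim):

// Paraphrased sketch of the wrapping chain, not the verbatim pt/xla source.
// The key point: the torch::lazy::Value is handed through unchanged, so the
// resulting tensor carries exactly the is_symbolic_ bits (or lack thereof)
// already present on the Value's torch::lazy::Shape.
XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value) const {
  return Create(std::move(ir_value), GetDevice(), dtype_optional());
}

XLATensorPtr XLATensor::Create(
    torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device,
    c10::optional<at::ScalarType> logical_element_type) {
  return c10::make_intrusive<XLATensor>(
      XLATensor(std::move(ir_value), device, logical_element_type));
}

// Data's constructor then stores the ir_value as-is; nothing along the way
// inspects or rewrites the shape's dynamic bits.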

Option 1: For all dynamic ops, the upstream torch/LTC layer sets the `torch::lazy::Shape::is_symbolic_` vector bits inside the `ir_value` node.

Option 2: For all dynamic ops, the PyTorch/XLA IR layer sets the `torch::lazy::Shape::is_symbolic_` vector bits inside the `ir_value` node. Below is the hypothetical code in tensor_methods.cpp:

XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input) {
  torch::lazy::NodePtr node =
      torch::lazy::MakeNode<NonZero>(input->GetIrValue());
  torch::lazy::Value ir_value = torch::lazy::Value(node, 0);
  ir_value.shape().set_symbolic(0);  // hypothetical new torch::lazy::Shape API to mark dim 0 as dynamic
  return input->CreateFrom(ir_value, at::ScalarType::Long);
}

My takeaway: I think it is better design to set the dynamic properties of an op at the upstream level; having the downstream control the dynamic bits doesn't seem reasonable. @Krovatkin wdyt?

CC @JackCaoG @Gamrix

Krovatkin commented 2 years ago

@miladm I believe it will be Option 2. We typically won't go through the LTC layer when tracing and generating IR nodes; XLA takes over directly, and we end up in `XLATensor::nonzero`, which will need to call `torch::lazy::` APIs (e.g. `with_symbolic_dims`).

Since we expect most ops to be code-generated, these calls will be inserted for you automatically.
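
For concreteness, `with_symbolic_dims` is a method on `torch::lazy::Shape` upstream. A minimal sketch of what such a hand-written call might look like follows; the `NonZero` node usage mirrors the hypothetical code above, and the step that attaches the adjusted shape back to the node is an assumption, since that plumbing is what codegen would provide:

// Sketch only. Assumes a NonZero IR node whose shape(0) is {N, rank}; how
// the adjusted Shape gets attached back to the node is left to codegen.
torch::lazy::NodePtr node =
    torch::lazy::MakeNode<NonZero>(input->GetIrValue());
// nonzero's dim 0 (the row count) is data-dependent; dim 1 (the rank) is static.
torch::lazy::Shape dyn_shape =
    node->shape(0).with_symbolic_dims(std::vector<bool>{true, false});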

miladm commented 2 years ago

Thanks @Krovatkin!

I wrote some code to update the dynamic properties of the `nonzero` IR Value in PT/XLA. It turns out the problem I described earlier persists. Refer to this PR commit for details. We can discuss this in our upcoming offline chat.

In short, this line returns `false` even though this line returns `true`. I expect both to return `true`.
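
For readers without the linked lines, here is a hedged reconstruction of the two checks being compared (`bridge::GetXlaTensor` and `GetIrValue` exist in pt/xla; the exact linked lines are not reproduced here, so the details are illustrative):

// Illustrative reconstruction of the mismatch; see the hedges above.
torch::Tensor xla_y = torch::nonzero(xla_x);

// Check 1 (aten level): currently returns false.
bool aten_symbolic = xla_y.sym_sizes()[0].is_symbolic();

// Check 2 (lazy IR level): returns true after the commit above, reading the
// is_symbolic_ bits off the Value's torch::lazy::Shape.
torch::lazy::Value ir_value = bridge::GetXlaTensor(xla_y)->GetIrValue();
const auto& symbolic = ir_value.shape().is_symbolic();  // c10::optional<std::vector<bool>>
bool ir_symbolic = symbolic.has_value() && (*symbolic)[0];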