pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Should the condition for deciding whether a tensor needs to be synced in torch_xla be improved? #5701

Open · ckfgihub opened this issue 1 year ago

ckfgihub commented 1 year ago

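Should the condition that decides whether a tensor needs to be synced in torch_xla be improved? I propose adding the condition !torch::lazy::getBackend()->GetComputationDataFromNode(tensors[i]->GetIrValue().node.get()), so that a tensor without a current data handle is synced only if its IR node does not already carry backend data.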

torch_xla/csrc/xla_graph_executor.cpp

Original code:

    XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
        const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
      // ...
      for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
            // A tensor's xla_data might not be up to date if there is a view
            // associated with it. Make sure to sync those tensors here too.
            (tensors[i]->CurrentDataHandle() == nullptr ||  // <== proposed change goes here
             (tensors[i]->data()->view != nullptr &&
              !tensors[i]->data()->view->IsUpToDate()))) {
          // ...
        }
      }
      // ...
    }

Modified code:

    XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
        const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
      // ...
      for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
            // A tensor's xla_data might not be up to date if there is a view
            // associated with it. Make sure to sync those tensors here too.
            ((tensors[i]->CurrentDataHandle() == nullptr &&
              // Proposed: do not sync a tensor whose IR node already holds
              // backend data (e.g. a DeviceData node).
              !torch::lazy::getBackend()->GetComputationDataFromNode(
                  tensors[i]->GetIrValue().node.get())) ||
             (tensors[i]->data()->view != nullptr &&
              !tensors[i]->data()->view->IsUpToDate()))) {
          // ...
        }
      }
      // ...
    }
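To make the difference between the two conditions concrete, here is a minimal Python model of the check. It is a sketch, not torch_xla code: FakeNode, FakeTensor, computation_data_from_node, collect and the "DeviceData"/"Add" kind labels are hypothetical stand-ins. It assumes that GetComputationDataFromNode returns the backend data when the IR node is a DeviceData node and null otherwise, which is how the check is described in this thread.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeNode:
    """Hypothetical stand-in for a lazy IR node."""
    kind: str                      # illustrative labels, e.g. "DeviceData" or "Add"
    data: Optional[object] = None  # backend data, only set on DeviceData nodes


@dataclass
class FakeTensor:
    """Hypothetical stand-in for an XLATensor."""
    uid: int                       # models GetUniqueId()
    data_handle: Optional[object]  # models CurrentDataHandle()
    ir_node: Optional[FakeNode]    # models GetIrValue().node
    stale_view: bool = False       # models view != nullptr && !view->IsUpToDate()


def computation_data_from_node(node: Optional[FakeNode]) -> Optional[object]:
    """Models GetComputationDataFromNode: a DeviceData node's data, else None."""
    if node is not None and node.kind == "DeviceData":
        return node.data
    return None


def needs_sync_original(t: FakeTensor) -> bool:
    return t.data_handle is None or t.stale_view


def needs_sync_proposed(t: FakeTensor) -> bool:
    # A missing handle only counts if the IR node does not already carry data.
    missing = t.data_handle is None and computation_data_from_node(t.ir_node) is None
    return missing or t.stale_view


def collect(tensors: List[FakeTensor], needs_sync) -> List[int]:
    """Return the unique ids selected for syncing, in input order."""
    seen = set()
    picked = []
    for t in tensors:
        # Like tensor_ids.insert(...).second: the id is recorded either way,
        # and only its first occurrence is considered.
        first_time = t.uid not in seen
        seen.add(t.uid)
        if first_time and needs_sync(t):
            picked.append(t.uid)
    return picked


# A parameter whose IR is a DeviceData node that already holds a buffer,
# a pending computation, and a tensor whose view is out of date.
weight = FakeTensor(1, None, FakeNode("DeviceData", data="buf0"))
pending = FakeTensor(2, None, FakeNode("Add"))
viewed = FakeTensor(3, "buf1", None, stale_view=True)
tensors = [weight, pending, weight, viewed]

original_ids = collect(tensors, needs_sync_original)  # [1, 2, 3]
proposed_ids = collect(tensors, needs_sync_proposed)  # [2, 3]
```

In this model the only tensor that drops out under the proposed condition is the one whose IR node already holds data; pending computations and stale views are still synced.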

JackCaoG commented 1 year ago

So essentially you are saying we should check whether the IR node is a DeviceData node and whether it holds data. Did you run into a corner case where this change is needed? Can you give me an example?

ckfgihub commented 11 months ago

I think this graph (the MHLO graph below) is useless and wastes both compilation time and run time, which is why I proposed the optimization. With the change above applied, the following MHLO graph is no longer generated. My example code is as follows:

Source code:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla
    import torch_xla.core.xla_model as xm
    import os
    from torch_xla.utils.utils import get_free_tcp_ports
    from torch_xla.amp import syncfree


    class SimpleModel(nn.Module):
        def __init__(self, input_size, output_size):
            super(SimpleModel, self).__init__()
            self.linear = nn.Linear(input_size, output_size)

        def forward(self, x):
            return self.linear(x)


    input_size = 10
    output_size = 5
    device = xm.xla_device()
    # Moving the model to the XLA device places its parameters there.
    model = SimpleModel(input_size, output_size).to(device)
    # mark_step() cuts the lazy graph and syncs pending tensors.
    xm.mark_step()

MHLO graph:

    module @SyncTensorsGraph.6 {
      func.func @main(%arg0: tensor<5x10xf32>, %arg1: tensor<5xf32>, %arg2: tensor<1x10xf32>, %arg3: tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>> {
        %0 = "mhlo.tuple"(%arg0, %arg1, %arg2, %arg3) {xla_shape = "(f32[5,10]{1,0}, f32[5]{0}, f32[1,10]{1,0}, f32[1,5]{1,0})"} : (tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
        return %0 : tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
      }
    }
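The graph above performs no computation: its only op, %0, packs the four parameters into an mhlo.tuple that is returned unchanged. As a rough illustration of why it counts as wasted work, the sketch below flags such pass-through modules. is_pass_through is a hypothetical helper, not a torch_xla or MLIR API; it uses a deliberately simplified regex over generic-form MHLO text, so it only handles simple modules like the one in this issue. WITH_ADD is a hand-written hypothetical module used for contrast.

```python
import re

# The module from this issue, copied verbatim.
PASS_THROUGH = '''module @SyncTensorsGraph.6 {
  func.func @main(%arg0: tensor<5x10xf32>, %arg1: tensor<5xf32>, %arg2: tensor<1x10xf32>, %arg3: tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>> {
    %0 = "mhlo.tuple"(%arg0, %arg1, %arg2, %arg3) {xla_shape = "(f32[5,10]{1,0}, f32[5]{0}, f32[1,10]{1,0}, f32[1,5]{1,0})"} : (tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
    return %0 : tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
  }
}'''

# A hypothetical module that actually computes something before returning.
WITH_ADD = '''module @Example {
  func.func @main(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tuple<tensor<5xf32>> {
    %0 = "mhlo.add"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
    %1 = "mhlo.tuple"(%0) : (tensor<5xf32>) -> tuple<tensor<5xf32>>
    return %1 : tuple<tensor<5xf32>>
  }
}'''


def is_pass_through(mhlo_text: str) -> bool:
    """True if the only op is an mhlo.tuple over @main's arguments, in order."""
    sig = re.search(r"func\.func @main\(([^)]*)\)", mhlo_text)
    if sig is None:
        return False
    params = re.findall(r"%arg\d+", sig.group(1))
    # Generic-form ops look like: %0 = "dialect.op"(%a, %b) ...
    ops = re.findall(r'=\s*"([\w.]+)"\(([^)]*)\)', mhlo_text)
    if len(ops) != 1 or ops[0][0] != "mhlo.tuple":
        return False
    operands = [o.strip() for o in ops[0][1].split(",")]
    return operands == params


print(is_pass_through(PASS_THROUGH))  # True
print(is_pass_through(WITH_ADD))      # False
```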

JackCaoG commented 11 months ago

I vaguely remember trying to fix this a while back by not syncing on the DeviceData IR and instead manually replacing those tensors with the underlying data (linked to the DeviceData node). I think I ran into a couple of issues and then kind of stopped. Feel free to open a PR and see whether any tests break.
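The alternative described above can be sketched as a pre-pass that runs before tensors are collected for syncing. This is a minimal Python model under loose assumptions, not the actual torch_xla change: FakeNode, FakeTensor, split_pending and the "DeviceData"/"Add" kind labels are hypothetical names, and the model ignores views and the rest of the state a real XLATensor carries.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeNode:
    """Hypothetical stand-in for a lazy IR node."""
    kind: str
    data: Optional[object] = None


@dataclass
class FakeTensor:
    """Hypothetical stand-in for an XLATensor (views ignored)."""
    uid: int
    data_handle: Optional[object]
    ir_node: Optional[FakeNode]


def split_pending(tensors: List[FakeTensor]) -> List[FakeTensor]:
    """Attach DeviceData buffers directly; return only tensors that still need a graph."""
    to_compile = []
    for t in tensors:
        if t.data_handle is not None:
            continue  # already materialized, nothing to sync
        node = t.ir_node
        if node is not None and node.kind == "DeviceData" and node.data is not None:
            # Reuse the buffer the DeviceData node points at instead of
            # routing it through a compiled pass-through graph.
            t.data_handle = node.data
            t.ir_node = None  # in this model, a tensor with data drops its IR
        else:
            to_compile.append(t)
    return to_compile


weight = FakeTensor(1, None, FakeNode("DeviceData", data="buf0"))
pending = FakeTensor(2, None, FakeNode("Add"))
ready = FakeTensor(3, "buf2", None)

remaining = split_pending([weight, pending, ready])
print([t.uid for t in remaining])  # [2]
print(weight.data_handle)          # buf0
```

Unlike the proposed condition, which only skips such tensors, this approach rewrites tensor state, so it touches more of the tensor's invariants.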