ckfgihub opened this issue 1 year ago (status: Open)
Pretty much you are saying we should check whether the IR node is a DeviceData and whether it already has data. Did you run into a corner case where this change is needed? Can you give me an example?
I think this graph (the MHLO graph below) is useless and a waste of compilation and run time, which is why I am proposing the optimization. If I add the proposed optimization, the following MHLO graph will not be generated. My example code is as follows:
Source code:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import os
from torch_xla.utils.utils import get_free_tcp_ports
from torch_xla.amp import syncfree


class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)


input_size = 10
output_size = 5
device = xm.xla_device()
model = SimpleModel(input_size, output_size).to(device)
xm.mark_step()
```
MHLO graph:

```mlir
module @SyncTensorsGraph.6 {
  func.func @main(%arg0: tensor<5x10xf32>, %arg1: tensor<5xf32>, %arg2: tensor<1x10xf32>, %arg3: tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>> {
    %0 = "mhlo.tuple"(%arg0, %arg1, %arg2, %arg3) {xla_shape = "(f32[5,10]{1,0}, f32[5]{0}, f32[1,10]{1,0}, f32[1,5]{1,0})"} : (tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>) -> tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
    return %0 : tuple<tensor<5x10xf32>, tensor<5xf32>, tensor<1x10xf32>, tensor<1x5xf32>>
  }
}
```
I vaguely remember trying to fix this issue a while back by not syncing on the DeviceData IR and manually replacing it with the underlying Data (linked with the DeviceData node). I think I ran into a couple of issues and then stopped. Feel free to open a PR and see if any tests break.
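For context, here is a minimal sketch of what that approach could look like, assuming torch_xla's existing types (`XLATensorPtr`, the `torch::lazy` backend interface). Only `GetIrValue` and `GetComputationDataFromNode` are taken from this issue; the helper name and overall shape are illustrative, not an actual torch_xla function:

```cpp
// Sketch only (not a tested patch): find tensors whose IR value is already a
// DeviceData node backed by device data, so their value can be reused
// directly instead of emitting a graph that merely returns it.
std::vector<std::pair<XLATensorPtr, torch::lazy::BackendDataPtr>>
CollectShortCircuitableTensors(const std::vector<XLATensorPtr>& tensors) {
  std::vector<std::pair<XLATensorPtr, torch::lazy::BackendDataPtr>> result;
  for (const XLATensorPtr& tensor : tensors) {
    torch::lazy::Value ir_value = tensor->GetIrValue();
    // Returns the backend data if the node is a DeviceData node that already
    // owns a buffer on the device, otherwise nullptr.
    torch::lazy::BackendDataPtr data =
        torch::lazy::getBackend()->GetComputationDataFromNode(
            ir_value.node.get());
    if (data != nullptr) {
      // Instead of syncing, one would attach `data` as the tensor's data
      // handle here; the exact torch_xla call for that is the open part of
      // this sketch (and presumably where the earlier attempt hit issues).
      result.emplace_back(tensor, data);
    }
  }
  return result;
}
```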
Could the condition torch_xla uses to judge whether a tensor needs to be synchronized be improved? The proposal is to add the condition: `!torch::lazy::getBackend()->GetComputationDataFromNode(tensors[i]->GetIrValue().node.get())`
In `torch_xla/csrc/xla_graph_executor.cpp`:
Original code:

```cpp
XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
    const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
  // .....
  for (size_t i = 0; i < tensors.size(); ++i) {
    if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
        // A tensor's xla_data might not be up to date if there is a view
        // associated with it. Make sure to sync those tensors here too.
        (tensors[i]->CurrentDataHandle() == nullptr ||  // <== condition modified below
         (tensors[i]->data()->view != nullptr &&
          !tensors[i]->data()->view->IsUpToDate()))) {
```
Modified code:

```cpp
XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
    const std::vector<XLATensorPtr>& tensors, const SyncTensorsConfig& config) {
  // .....
  for (size_t i = 0; i < tensors.size(); ++i) {
    if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
        // A tensor's xla_data might not be up to date if there is a view
        // associated with it. Make sure to sync those tensors here too.
        ((tensors[i]->CurrentDataHandle() == nullptr &&
          !torch::lazy::getBackend()->GetComputationDataFromNode(
              tensors[i]->GetIrValue().node.get())) ||
         (tensors[i]->data()->view != nullptr &&
          !tensors[i]->data()->view->IsUpToDate()))) {
```
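Reading the modified condition as a standalone predicate may make the intent clearer. The helper below is only an illustrative restatement of the change above, not proposed code; the view-related branch of the original condition is unchanged:

```cpp
// Illustrative helper only: a tensor is collected for sync when it has no
// current data handle AND its IR node cannot already be resolved to existing
// backend data (i.e. it is not a DeviceData node whose buffer is already on
// the device). For the parameters in the example above, which were just moved
// to the device and never computed on, this returns false, so the trivial
// tuple-only graph would not be built or compiled.
bool NeedsSync(const XLATensorPtr& tensor) {
  return tensor->CurrentDataHandle() == nullptr &&
         !torch::lazy::getBackend()->GetComputationDataFromNode(
             tensor->GetIrValue().node.get());
}
```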