pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[torch-xla 2.1 - 2.4] when functionalization is on, there are no aliasing for gradients when using gradient accumulation #7174

Closed: jeffhataws closed this issue 2 months ago

jeffhataws commented 4 months ago

🐛 Bug

When functionalization is on (XLA_DISABLE_FUNCTIONALIZATION=0), I see that there are fewer aliased tensors. Jack has a patch to increase the number of aliased tensors https://github.com/pytorch/xla/commit/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd . However, even though this change helps increase the number of aliased tensors, it still seems to be missing aliasing for gradients when gradient accumulation is used.

Using test_train_mp_mnist.py, make the modifications below. I added a mark_step to isolate the gradient accumulation loops.

@@ -158,16 +163,19 @@ def train_mnist(flags, **kwargs):
       output = model(data)
       loss = loss_fn(output, target)
       loss.backward()
-      if flags.ddp:
-        optimizer.step()
-      else:
-        xm.optimizer_step(optimizer)
-      tracker.add(flags.batch_size)
-      if step % flags.log_steps == 0:
-        xm.add_step_closure(
-            _train_update,
-            args=(device, step, loss, tracker, epoch, writer),
-            run_async=flags.async_closures)
+
+      if step % 4 == 0:
+          xm.mark_step()
+          if flags.ddp:
+            optimizer.step()
+          else:
+            xm.optimizer_step(optimizer)
+          tracker.add(flags.batch_size)
+          if step % flags.log_steps == 0:
+            xm.add_step_closure(
+                _train_update,
+                args=(device, step, loss, tracker, epoch, writer),
+                run_async=flags.async_closures)

I only see 2 aliases even though we expect all the gradient tensors to be aliased:

2024-06-03 21:15:37.676472: I torch_xla/csrc/xla_graph_executor.cpp:1462] Parameter sequence graph hash b8e15ed0391b82171706a34d84ca8ea0
2024-06-03 21:15:37.678822: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 13 with output 4: s64[]
2024-06-03 21:15:37.678862: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 14 with output 5: s64[]
2024-06-03 21:15:37.679222: I torch_xla/csrc/xla_graph_executor.cpp:1397] Compiling IR graph hash b8e15ed0391b82171706a34d84ca8ea0 on device CPU:0 ...

To Reproduce

Steps to reproduce the behavior:

  1. Check out r2.1_aws_neuron branch
  2. Apply a patch from Jack https://github.com/pytorch/xla/commit/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd
  3. Build/install as described in the CONTRIBUTING doc
  4. Go into xla/test
  5. Edit test_train_mp_mnist.py and add the gradient accumulation loop as above.
  6. Run with TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" to see aliasing debugging logs:
    XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"   TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" python test_train_mp_mnist.py |& tee log

Expected behavior

Expect gradients to be aliased
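
One way to check this expectation is the InputOutputAliasCount metric used later in this thread (a sketch only; it would go right after the mark_step in the accumulation loop, and metric output formats can vary by version):

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

# Inside the training loop, right after xm.mark_step():
print("InputOutputAliasCount:", met.metric_data("InputOutputAliasCount"))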

Environment

Additional context

JackCaoG commented 4 months ago

so I tried to test it on nightly, with

diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py
index b66780d5c..acd00edff 100644
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -11,6 +11,8 @@ import torch_xla
 import torchvision
 import torch.optim as optim
 import torch.nn as nn
+import torch_xla.debug.metrics as met
+

 class TrainResNetBase():
@@ -29,7 +31,7 @@ class TrainResNetBase():
         xr.world_size())

     self.device = torch_xla.device()
-    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
+    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device, batches_per_execution=4)
     self.model = torchvision.models.resnet50().to(self.device)
     self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
     self.loss_fn = nn.CrossEntropyLoss()
@@ -51,6 +53,8 @@ class TrainResNetBase():
       loss.backward()
       self.run_optimizer()
       tracker.add(self.batch_size)
+      print('alias count = ')
+      print(met.metric_data("InputOutputAliasCount"))
       if step % 10 == 0:
         xm.add_step_closure(
             self._train_update, args=(step, loss, tracker, epoch))

we can do gradient accumulation every 4 steps for resnet50. When run with

PT_XLA_DEBUG_LEVEL=1 python examples/train_resnet_base.py 

I see

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step in parallel loader at step end
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: e58a0059c5ae1ac0035a1d6fcefbc377
Compilation Analysis:   Number of Graph Inputs: 272
Compilation Analysis:   Number of Graph Outputs: 486
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Compilation Analysis:   __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Compilation Analysis:   train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:49)
Compilation Analysis:   start_training (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:67)
Compilation Analysis:   <module> (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:75)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 0.167664 GB
Post Compilation Analysis: Graph output size: 0.199731 GB
Post Compilation Analysis: Aliased Input size: 0.095884 GB
Post Compilation Analysis: Intermediate tensor size: 8.436128 GB
Post Compilation Analysis: Compiled program size: 0.194783 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
epoch: 1, step: 0, loss: 6.9063191413879395, rate: 8.059381469430795
alias count = 
(1, 267.0, ((1717453831.561344, 267.0),))

Note that 0.095884 GB of input is aliased, out of 0.167664 GB, and 267/272 of inputs are being aliased. Now if I run it again with

XLA_DISABLE_FUNCTIONALIZATION=1 PT_XLA_DEBUG_LEVEL=1 python examples/train_resnet_base.py

I see exactly the same thing.

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 0.167664 GB
Post Compilation Analysis: Graph output size: 0.192098 GB
Post Compilation Analysis: Aliased Input size: 0.095884 GB
Post Compilation Analysis: Intermediate tensor size: 8.455472 GB
Post Compilation Analysis: Compiled program size: 0.194823 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
epoch: 1, step: 0, loss: 6.912147045135498, rate: 8.488426770395343
alias count = 
(1, 267.0, ((1717454064.4393814, 267.0),))

This suggests that on nightly, XLA_DISABLE_FUNCTIONALIZATION does not impact aliasing. It is hard to fix the 2.1 branch at this point. I would suggest just running with XLA_DISABLE_FUNCTIONALIZATION=1 until you upgrade to the 2.3 release.

JackCaoG commented 3 months ago

Ok, here is how I would debug this problem. Let's first understand when and how we do buffer aliasing in the 2.1 branch.

https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1370-L1371

Building the input-output alias happens during compilation, so it is a compile-time decision. https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1234C19-L1236 It takes the tensors that represent the outputs (tensors + indices are all XLATensors that will get a new output value after the current execution) and the lowering context as input.

We first find the tensor ID of all output tensors https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1242-L1246

and then we loop through all of the input data (from the lowering context) https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1247-L1249

and see if any input data shares the same tensor ID as an output. If they share the same ID, it means that some in-place operation happened and the XLATensor is no longer attached to this input data, so this input data can be aliased to the output with the same shape.

Note we make a couple of assumptions here:

  1. There is only one XLATensor associated with a given XLAData, which is usually the case
  2. XLAData's ID correctly reflects the owner of the XLATensor

In order to debug the issue above, I think what you want to do is:

  1. dump all of the output tensors' TensorIDs
  2. dump all of the input data's TensorIDs

and see why they don't match.
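
From Python, a rough sketch of the output side is below (it assumes the `model` from the MNIST repro, and uses the same private `_get_xla_tensor_debug_info` helper that appears later in this thread; its output format may change between versions). The input data's TensorIDs still need the C++ logging added further down.

import torch_xla
import torch_xla.debug.metrics as met

# After loss.backward(), dump each grad's debug info (which includes its TensorID)
# plus the alias-count metric.
for name, p in model.named_parameters():
    if p.grad is not None:
        print(name)
        print(torch_xla._XLAC._get_xla_tensor_debug_info(p.grad))
print(met.metric_data("InputOutputAliasCount"))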

jeffhataws commented 3 months ago

Just want to document some findings using the original MLP test on CPU only, printing met.metric_data("InputOutputAliasCount"):

pt2.1 with functionalization on CPU:
(4, 42.0, ((1719204665.609187, 16.0), (1719204667.7811365, 2.0), (1719204668.486205, 12.0), (1719204668.6342065, 12.0))) 

pt2.1 without functionalization on CPU:
(4, 58.0, ((1719204732.9488149, 16.0), (1719204735.0688317, 6.0), (1719204735.781659, 12.0), (1719204735.9386797, 24.0)))

pt2.3 with functionalization on CPU:
(4, 42.0, ((1719204826.4645836, 16.0), (1719204829.2127182, 2.0), (1719204829.8989058, 12.0), (1719204830.0465565, 12.0)))

pt2.3 without functionalization on CPU:
(4, 58.0, ((1719204809.890676, 16.0), (1719204812.663207, 6.0), (1719204813.3367858, 12.0), (1719204813.4833145, 24.0)))  

pt2.4 with functionalization on CPU:
(4, 54.0, ((1719207116.7393954, 16.0), (1719207119.7452862, 2.0), (1719207120.3934603, 12.0), (1719207120.5731502, 24.0)))

pt2.4 without functionalization on CPU:
(4, 58.0, ((1719207140.930515, 16.0), (1719207143.827449, 6.0), (1719207144.4554996, 12.0), (1719207144.6054273, 24.0)))

jeffhataws commented 3 months ago

Minimal reproduction with only 1 linear layer and only gradient accumulation:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import numpy as np
import copy

dev = xm.xla_device()

train_x_sample = torch.rand((1, 28 * 28))
train_label_sample = torch.tensor([5])

class MLP(nn.Module):
  def __init__(self, input_size = 28 * 28, output_size = 10):
      super(MLP, self).__init__()
      self.fc1 = nn.Linear(input_size, output_size, bias=False)

  def forward(self, x):
      x = self.fc1(x)
      return F.log_softmax(x, dim=1)

c_model = MLP().to('cpu')
t_model = copy.deepcopy(c_model).to(dev)
loss_fn = nn.NLLLoss()

def try_grad_accum(model, device, train_x, train_label, accum_steps):
    train_x = train_x.to(device)
    train_label = train_label.to(device)
    model.zero_grad()
    for i in range(accum_steps):
        output = model(train_x)
        t_loss = loss_fn(output, train_label)
        t_loss.backward()
        xm.mark_step()
    return [p.grad.to('cpu').numpy() for p in model.parameters()]

def test_grad_accum():
    t_model.train()
    c_model.train()
    accum_steps = 4
    c_grads_5 = try_grad_accum(c_model, 'cpu', train_x_sample, train_label_sample, accum_steps)
    t_grads_5 = try_grad_accum(t_model, dev, train_x_sample, train_label_sample, accum_steps)
    np.testing.assert_allclose(t_grads_5, c_grads_5, rtol=3e-2, atol=1e-3)
    print("XLA_DISABLE_FUNCTIONALIZATION:", os.environ["XLA_DISABLE_FUNCTIONALIZATION"],
          " InputOutputAliasCount:", met.metric_data("InputOutputAliasCount"))

if __name__ == '__main__':
    test_grad_accum()

Disable functionalization (see aliasing of gradient buffer):

XLA_FLAGS=--xla_dump_to="./dump_nofunc" XLA_DISABLE_FUNCTIONALIZATION=1 python gradacc_test.py
XLA_DISABLE_FUNCTIONALIZATION: 1  InputOutputAliasCount: (2, 1.0, ((1719295462.695567, 0.0), (1719295462.758401, 1.0)))

dump_nofunc/module_0001.SyncTensorsGraph.115.cpu_after_optimizations.txt also shows aliasing:

HloModule SyncTensorsGraph.115, is_scheduled=true, input_output_alias={ {0}: (3, {}, may-alias) }, entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, f32[10,784]{1,0})->(f32[10,784]{1,0}, f32[1,10]{1,0}, f32[])}

Enable functionalization (see no aliasing):

XLA_FLAGS=--xla_dump_to="./dump_func" XLA_DISABLE_FUNCTIONALIZATION=0 python gradacc_test.py
XLA_DISABLE_FUNCTIONALIZATION: 0  InputOutputAliasCount: (2, 0.0, ((1719295465.6997023, 0.0), (1719295465.7658002, 0.0)))

dump_func/module_0001.SyncTensorsGraph.117.cpu_after_optimizations.txt shows no aliasing:

HloModule SyncTensorsGraph.117, is_scheduled=true, entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, f32[10,784]{1,0})->(f32[1,10]{1,0}, f32[], f32[784,10]{1,0}, f32[10,784]{1,0})}

jeffhataws commented 3 months ago

Using TOT, I modified torch_xla/csrc/xla_graph_executor.cpp to dump some info:

diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 74c3270a9..d31924fb4 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1264,6 +1264,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
     size_t tensor_index = indices[i];
     int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
     output_tensor_id_map[tensor_id] = i;
+    TF_VLOG(3) << "SetBufferDonors indices_index: " << i << " tensor_index: " << tensor_index << " tensor_id: " << tensor_id;
   }
   const auto& parameters_data = lowering_ctx->GetParametersData();
   std::vector<ssize_t> alias_map(indices.size(), -1);
@@ -1271,8 +1272,10 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
     auto* data_info =
         static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
             parameters_data[i]->info());
+    TF_VLOG(3) << "SetBufferDonors param index: " << i << " data_info: " << data_info;
     if (data_info != nullptr && !data_info->read_only) {
       auto it = output_tensor_id_map.find(data_info->tensor_id);
+      TF_VLOG(3) << "SetBufferDonors data_info->tensor_id: " << data_info->tensor_id;
       // Parameter buffer's TensorId in output_tensor_id_map means
       // this buffer is not needed after execution since XLATensor will get a
       // new buffer.
@@ -1284,6 +1287,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
     }
   }
   TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
+  TF_VLOG(3) << "SetBufferDonors buffer_donor_indexs: " << buffer_donor_indexs;
   return buffer_donor_indexs;
 }

Disable functionalization (timestamps removed to improve readability):

torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 6
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 8
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5583fe936890
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5583fe968640
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5583fe902550
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash a7fe2231121157aab2b76c36f9085caf on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash a7fe2231121157aab2b76c36f9085caf on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 17
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 19
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5583fe936890
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5583fe968640
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5583fe903160
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 3 data_info: 0x7f7910019030
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs: 3
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
XLA_DISABLE_FUNCTIONALIZATION: 1  InputOutputAliasCount: (2, 1.0, ((1719438778.8729675, 0.0), (1719438778.9325185, 1.0)))

With functionalization:

torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 6
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 8
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 13
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5556f7ef6d60
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5556f7f5ceb0
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5556f7ef5910
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 86eb0fe9e6afb2f4cec87b9250f18010 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 86eb0fe9e6afb2f4cec87b9250f18010 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 17
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 19
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 28
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id: 29
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5556f7ef6d60
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5556f7f4a180
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5556f7f577e0
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 3 data_info: 0x7fb9c4018300
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
XLA_DISABLE_FUNCTIONALIZATION: 0  InputOutputAliasCount: (2, 0.0, ((1719438794.6676075, 0.0), (1719438794.7328753, 0.0)))

jeffhataws commented 3 months ago

In the functionalization graph, there's one additional output of shape (784, 10) that's not there in the no-functionalization case.

JackCaoG commented 3 months ago

Ok, I was able to repro and I have a rough idea about what's going on. This bug seems to only happen with gradient accumulation.

So we first have model.fc1.weight.grad after the first step, and in the second step PyTorch will use += to accumulate the grad. The tensor ID for model.fc1.weight.grad was 14 before the second step, and during the backward, PyTorch/XLA received a call to XLANativeFunctions::_propagate_xla_data which tells us an in-place operation happened between tensor ID 14 and tensor ID 26. PyTorch/XLA learned to link these two tensors together. However, after the backward is finished, the tensor ID of the new model.fc1.weight.grad becomes 29. In this case PyTorch/XLA can't tell that tensor ID 14 and tensor ID 29 are the same tensor, hence it will not be aliased.
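
To make the mismatch concrete, here is a toy model of that bookkeeping in plain Python (the IDs are the ones from the SetBufferDonors logs earlier in this thread; the real logic lives in C++):

# SetBufferDonors effectively builds {output_alias_id: output_index} and then
# donates any graph input whose tensor_id shows up in that map.
input_tensor_ids = [1, 2, 3, 14]                 # graph parameters seen in the logs above

# What we want: the accumulated grad keeps alias_id 14 through _propagate_xla_data.
output_alias_ids = {17: 0, 19: 1, 28: 2, 14: 3}
print([i for i, tid in enumerate(input_tensor_ids) if tid in output_alias_ids])  # [3]

# What actually happens: the regenerated grad carries alias_id 29, so nothing matches,
# which is why buffer_donor_indexs comes back empty in the logs above.
output_alias_ids = {17: 0, 19: 1, 28: 2, 29: 3}
print([i for i, tid in enumerate(input_tensor_ids) if tid in output_alias_ids])  # []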

@bdhirsh can you point me to where the gradient accumulation happens in PyTorch? It happens in C++ so it is a bit hard for me to trace the exact location without doing a debug build for both PyTorch and PyTorch/XLA.

My theory is that functionalization did something weird for the gradient accumulation which confuses PyTorch/XLA, so we can't really tell this is an in-place call.

JackCaoG commented 3 months ago

@jeffhataws what I did was adding

print(torch_xla._XLAC._get_xla_tensor_debug_info(model.fc1.weight.grad))

after the mark_step, and adding the following debug prints:

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 5e233d051..e222231f3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2707,6 +2707,12 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
   // 1) Aid XLA's InputOutputAlias.
   auto input_tensor = bridge::GetXlaTensor(input);
   auto output_tensor = bridge::GetXlaTensor(output);
+  std::cout << "output tensor's alias ID: " << output_tensor->data()->alias_id
+            << "\n";
+  std::cout << "input tensor's tensor ID: " << input_tensor->GetUniqueId()
+            << "\n";
+  std::cout << "input tensor's alias ID: " << input_tensor->data()->alias_id
+            << "\n";
   if (input_tensor->CurrentDataHandle() != nullptr ||
       (input_tensor->CurrentIrValue().node != nullptr &&
        torch_xla::DeviceData::Cast(
@@ -2722,7 +2728,7 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
     xm.mark_step()
     // x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2
     // for this graph
-    x *= 1 of 1
+    x *= 1
     */
     output_tensor->data()->alias_id = input_tensor->GetUniqueId();
   } else {
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0cbb21c44..bd8a9af06 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1275,6 +1275,8 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
   for (size_t i = 0; i < indices.size(); ++i) {
     size_t tensor_index = indices[i];
     int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
+    std::cout << "adding alias_id " << tensor_id << " to the map"
+              << "\n";
     output_tensor_id_map[tensor_id] = i;
   }
   const auto& parameters_data = lowering_ctx->GetParametersData();
@@ -1283,6 +1285,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
     auto* data_info =
         static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
             parameters_data[i]->info());
+    std::cout << "parameter tensor id = " << data_info->tensor_id << "\n";
     if (data_info != nullptr && !data_info->read_only) {
       auto it = output_tensor_id_map.find(data_info->tensor_id);
       // Parameter buffer's TensorId in output_tensor_id_map means

bdhirsh commented 3 months ago

Hmm @JackCaoG the grad accumulation impl is here: https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/functions/accumulate_grad.h#L116
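
For readers of this thread, the relevant behavior boils down to roughly the following (a simplified Python analogue, not the actual C++ implementation, which has extra fast paths and gradient-stealing logic):

import torch

def accumulate_grad(param: torch.nn.Parameter, new_grad: torch.Tensor) -> None:
    if param.grad is None:
        # First backward: install the detached gradient.
        param.grad = new_grad.detach()
    else:
        # Later backwards: in-place += on the existing .grad tensor; this is the
        # mutation that functionalization has to replay, as discussed above.
        param.grad += new_grad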

jeffhataws commented 3 months ago

One additional datapoint is that when I increase the gradient accumulation count, I see that the input tensor ID keeps changing for each gradient accumulation step when functionalization is on. Is that expected?

xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 17 to the map to indices_index 0
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id(alias_id): 17
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 19 to the map to indices_index 1
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id(alias_id): 19    
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 28 to the map to indices_index 2             
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id(alias_id): 28
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 29 to the map to indices_index 3
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id(alias_id): 29
xla_graph_executor.cpp:1276] SetBufferDonors param index: 0 data_info->tensor_id: 1
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 1
xla_graph_executor.cpp:1276] SetBufferDonors param index: 1 data_info->tensor_id: 2
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 2
xla_graph_executor.cpp:1276] SetBufferDonors param index: 2 data_info->tensor_id: 3
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 3
xla_graph_executor.cpp:1276] SetBufferDonors param index: 3 data_info->tensor_id: 14
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 14
xla_graph_executor.cpp:1291] SetBufferDonors buffer_donor_indexs: 
xla_graph_executor.cpp:1412] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1419] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 41 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 29
aten_xla_type.cpp:2716] input tensor's alias ID: 29 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 56 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 44
aten_xla_type.cpp:2716] input tensor's alias ID: 44 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 71 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 59
aten_xla_type.cpp:2716] input tensor's alias ID: 59 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 86 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 74
aten_xla_type.cpp:2716] input tensor's alias ID: 74 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
JackCaoG commented 3 months ago

The tensor ID changing is expected; with functionalization there are no real in-place ops. We can still alias the underlying buffer as long as we know they are supposed to be the same tensor (through _propagate_xla_data).
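
For contrast, a small sketch of the case where that linking does work (a plain in-place op on a device-data tensor; illustrative only, and the "should" is an assumption based on the behavior described in this thread):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

dev = xm.xla_device()
x = torch.zeros(4, device=dev)
xm.mark_step()                 # materialize x as device data
x.add_(1)                      # functionalized "in-place" add; _propagate_xla_data links old/new IDs
xm.mark_step()                 # x's old buffer should be donated/aliased in this graph
print(met.metric_data("InputOutputAliasCount"))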

jeffhataws commented 3 months ago

Makes sense. I think for the second graph, the input tensor_id 14 should be aliased to tensor_id 3, instead of being aliased to itself (as indicated by the map in SetBufferDonors: "adding alias_id 14 to the map to indices_index 3"). This way, tensor_id 3 would keep propagating.

JackCaoG commented 3 months ago

I found where tensor IDs 27-29 came from:

131│ bool FunctionalStorageImpl::apply_updates() {
132│   // N.B:none of the tensors used in this function should be FunctionalTensorWrappers at this point.
133│   // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
134│   // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
135│   // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
136│   at::AutoDispatchSkipFunctionalize guard;
137│   bool any_updates = !updates_.empty();
138│   for (auto& update_data: updates_) {
139├>    base_ = apply_update(update_data, base_);
140│   }
141│   updates_.clear();
142│   return any_updates;
143│ }

This will trigger a t_copy in pytorch/xla,

and it is from

 2252│   at::functionalization::impl::replace_(self, tmp_output);
 2253│   at::functionalization::impl::commit_update(self);
 2254├>  at::functionalization::impl::sync(self);
 2255│     return self;

which is part of

 2204│     at::Tensor & add__Tensor(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

in pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp.

Here is what happened. The tensor with ID 26 was created with an IR add; this is due to the add_ in the grad accumulation. _propagate_xla_data was correctly called and pytorch/xla knows that 26 and 14 are the same tensor. However, after the add is finished, at::functionalization::impl::sync triggered a couple of permute ops and in-place replaced the base_ with new tensors. There is no way for PyTorch/XLA to know that the new permute IR is the same tensor as the add tensor.

@bdhirsh can we also call the _propagate_xla_data for base_ before and after apply_update?

jeffhataws commented 3 months ago

Thanks @JackCaoG for the detailed root cause. @bdhirsh do you think this can be fixed in torch 2.4?

JackCaoG commented 2 months ago

I will give this a try today or tomorrow.

jeffhataws commented 2 months ago

Hi @JackCaoG , I tried this

diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 8d2c567c347..2392fe60a1d 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -589,7 +589,9 @@ def wrap_propagate_mutations_and_return(
   at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
   at::functionalization::impl::replace_({outer_arg}, {inner_ret});
   at::functionalization::impl::commit_update({outer_arg});
-  at::functionalization::impl::sync({outer_arg});"""
+  at::functionalization::impl::sync({outer_arg});
+  at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
+  """
         )

But I am still seeing an empty buffer_donor_indexs:


2024-07-04 23:04:51.386411: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 17 to the map to indices_index 0             
2024-07-04 23:04:51.386431: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id(alias_id): 17
2024-07-04 23:04:51.386439: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 19 to the map to indices_index 1
2024-07-04 23:04:51.386447: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id(alias_id): 19
2024-07-04 23:04:51.386456: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 28 to the map to indices_index 2
2024-07-04 23:04:51.386464: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id(alias_id): 28
2024-07-04 23:04:51.386473: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 29 to the map to indices_index 3
2024-07-04 23:04:51.386481: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id(alias_id): 29
2024-07-04 23:04:51.386491: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 0 data_info->tensor_id: 1  
2024-07-04 23:04:51.386501: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 1                                     
2024-07-04 23:04:51.386510: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 1 data_info->tensor_id: 2                        
2024-07-04 23:04:51.386519: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 2                                     
2024-07-04 23:04:51.386528: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 2 data_info->tensor_id: 3                        
2024-07-04 23:04:51.386537: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 3
2024-07-04 23:04:51.386546: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 3 data_info->tensor_id: 14
2024-07-04 23:04:51.386555: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 14                   
2024-07-04 23:04:51.386564: I torch_xla/csrc/xla_graph_executor.cpp:1291] SetBufferDonors buffer_donor_indexs:      
2024-07-04 23:04:51.386656: I torch_xla/csrc/xla_graph_executor.cpp:1412] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
2024-07-04 23:04:51.436769: I torch_xla/csrc/xla_graph_executor.cpp:1419] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
2024-07-04 23:04:51.436931: I torch_xla/csrc/xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
2024-07-04 23:04:51.436995: I torch_xla/csrc/xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
2024-07-04 23:04:51.437385: I torch_xla/csrc/aten_xla_type.cpp:2714] output tensor's tensor ID: 41
2024-07-04 23:04:51.437407: I torch_xla/csrc/aten_xla_type.cpp:2715] output tensor's alias ID: 41 shape: f32[10,784]
2024-07-04 23:04:51.437414: I torch_xla/csrc/aten_xla_type.cpp:2716] input tensor's tensor ID: 29
2024-07-04 23:04:51.437420: I torch_xla/csrc/aten_xla_type.cpp:2717] input tensor's alias ID: 29 shape: f32[10,784]
2024-07-04 23:04:51.437426: I torch_xla/csrc/aten_xla_type.cpp:2755] assigning output tensor's alias ID to input tensor's (unique) ID
2024-07-04 23:04:51.437460: I torch_xla/csrc/aten_xla_type.cpp:2714] output tensor's tensor ID: 41
2024-07-04 23:04:51.437467: I torch_xla/csrc/aten_xla_type.cpp:2715] output tensor's alias ID: 29 shape: f32[10,784]
2024-07-04 23:04:51.437473: I torch_xla/csrc/aten_xla_type.cpp:2716] input tensor's tensor ID: 44
2024-07-04 23:04:51.437479: I torch_xla/csrc/aten_xla_type.cpp:2717] input tensor's alias ID: 44 shape: f32[10,784]
2024-07-04 23:04:51.437485: I torch_xla/csrc/aten_xla_type.cpp:2769] assigning output tensor's alias ID to input tensor's alias ID

JackCaoG commented 2 months ago

@bdhirsh I gave this a try and did

-  for (auto& update_data: updates_) {
-    base_ = apply_update(update_data, base_);
+  for (auto& update_data : updates_) {
+    auto new_base_ = apply_update(update_data, base_);
+    at::functionalization::impl::propagate_xla_data(base_, new_base_);
+    base_ = new_base_;
   }

propagate_xla_data call will fail with

RuntimeError: isFunctionalTensor(functional_tensor) INTERNAL ASSERT FAILED at "/workspaces/dk3/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":710, please report a bug to PyTorch. 

I feel like I am close lol. Can you shed some light on the right way of doing this?

bdhirsh commented 2 months ago

hmm @JackCaoG , can you try out this patch?

diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp
index 3275c8f447f..d788aa29261 100644
--- a/aten/src/ATen/FunctionalStorageImpl.cpp
+++ b/aten/src/ATen/FunctionalStorageImpl.cpp
@@ -6,6 +6,13 @@
 #include <c10/util/Exception.h>
 #include <vector>

+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/_propagate_xla_data.h>
+#include <ATen/ops/_to_copy.h>
+#endif
+
 namespace at::functionalization {

 ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
@@ -136,7 +143,11 @@ bool FunctionalStorageImpl::apply_updates() {
   at::AutoDispatchSkipFunctionalize guard;
   bool any_updates = !updates_.empty();
   for (auto& update_data: updates_) {
+    at::Tensor base_old_ = base_;
     base_ = apply_update(update_data, base_);
+    if (base_.key_set().has(c10::DispatchKey::XLA)) {
+      at::_propagate_xla_data(base_old_, base_);
+    }
   }
   updates_.clear();
   return any_updates;

For the patch you posted above: you're calling at::functionalization::impl::propagate_xla_data(a, b), which expects the first tensor to be a functional wrapper and the second to be a normal tensor. You can just bypass that wrapper and call the op directly, with your two inner tensors.

JackCaoG commented 2 months ago

Thanks @bdhirsh, I think this partially fixed the problem. I found there is another place I need to fix:

442│ void FunctionalTensorWrapper::regenerate_from_base() {
443│   at::AutoDispatchSkipFunctionalize guard;
444│   auto storage_impl = functional_storage_impl();
445│   auto t = storage_impl->base();
446│
447│   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
448├>  t = apply_view_metas(t);
449│   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
450│
451│   replace_(t, /*from_lazy_regenerate=*/true);
452│   generation_ = storage_impl->generation();
453│ }

this is called by

423│ void FunctionalTensorWrapper::sync_() {
424│   if (is_up_to_date()) {
425│     return;
426│   }
427│   apply_updates();
428├>  regenerate_from_base();
429│ }

Hopefully this will fix the whole problem.

JackCaoG commented 2 months ago

ehh this is giving me too much headache lol

We first have tensor ID 13 with shape [784, 10]; I believe it is from the backward of the matmul.

create with at ir value, id 13
ir value [] aten::mm, xla_shape=f32[784,10]{1,0}, dynamic_dims: ()

Then we create the real grad with tensor ID 14, which is the transpose of 13:

create with at ir value, id 14
ir value [] aten::permute, xla_shape=f32[10,784]{0,1}, dynamic_dims: (), dims=(1, 0)

I checked that it is from

 9805│ variable_list TBackward0::apply(variable_list&& grads) {
 9806│
 9807│
 9808│   IndexRangeGenerator gen;
 9809│   auto self_ix = gen.range(1);
 9810│   variable_list grad_inputs(gen.size());
 9811│   const auto& grad = grads[0];
 9812│   bool any_grad_defined = any_variable_defined(grads);
 9813│   if (task_should_compute_output({ self_ix })) {
 9814├>    auto grad_result = any_grad_defined ? (grad.t()) : Tensor();
 9815│     copy_range(grad_inputs, self_ix, grad_result);
 9816│   }
 9817│   return grad_inputs;
 9818│ }

and if we print the debug info of model.fc1.weight.grad it is

XLATensor {
TensorID: 14
Device: TPU:0
XLA Shape: f32[10,784]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: TPU:0
  Data Shape: f32[10,784]
  Data Handle: None
Tensor on host: None
}

Now in the second step things become complicated. First we do the grad_accum, and the += creates the tensor with ID 26:

create with at ir value, id 26
ir value [] aten::add, xla_shape=f32[10,784]{1,0}, dynamic_dims: ()

This tensor is correctly linked to tensor ID 14. Up to here everything is as expected.

However in the FunctionalStorageImpl::apply_updates() I saw

create with at ir value, id 27
ir value [] aten::permute, xla_shape=f32[10,784]{1,0}, dynamic_dims: (), dims=(1, 0)
create with at ir value, id 28
ir value [] aten::permute, xla_shape=f32[784,10]{0,1}, dynamic_dims: (), dims=(1, 0)

and then we tried to _propagate_xla_data, but weirdly enough it links tensor ID 28 [784, 10] with tensor ID 13 [784, 10].

It then creates tensor ID 29

ir value [] aten::permute, xla_shape=f32[10,784]{1,0}, dynamic_dims: (), dims=(1, 0)

and links it to tensor ID 13.

The problem is that, from the XLA perspective, the input to the graph is tensor ID 14 (the transposed version) instead of 13 (non-transposed). I think the issue is that, from the functionalization perspective...

JackCaoG commented 2 months ago

ok I think I know what happened.

t14 = t13.t()
t14 += temp

t14's base is t13. After the += was applied on t14, sync_ will try to replay the operation from the base (t13), and that's where all of the xla_data starts to get ruined. I think what @jeffhataws did in https://github.com/pytorch/xla/issues/7174#issuecomment-2223677581 is an easier solution to this problem, but I think we need to do

at::functionalization::impl::propagate_xla_data({inner_ret}, {outer_arg});

(reverse the inner and outer). Let me try this approach.
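
For reference, a standalone sketch of the t14 = t13.t(); t14 += temp pattern above (purely illustrative; in the real run the view comes from TBackward0's grad.t() during backward, not from user code):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()
t13 = torch.randn(784, 10, device=dev)    # plays the role of tensor ID 13
t14 = t13.t()                             # view whose base is t13, like the weight grad
xm.mark_step()
t14 += torch.ones(10, 784, device=dev)    # the grad-accumulation add_ on the view
xm.mark_step()                            # sync_ replays the update from the base (t13),
                                          # producing the extra permutes described above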

JackCaoG commented 2 months ago

ok it failed with

RuntimeError: isFunctionalTensor(functional_tensor) INTERNAL ASSERT FAILED at "/workspaces/dk3/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":719, please report a bug to PyTorch. 

but I think it is the right thing to do.

@bdhirsh I think the issue is in

  at::functionalization::impl::propagate_xla_data(self, tmp_output);
  at::functionalization::impl::replace_(self, tmp_output);
  at::functionalization::impl::commit_update(self);
  at::functionalization::impl::sync(self);

self will be replaced in sync, so we want to re-propagate the xla_data from tmp_output to new_self (or from old_self to new_self). I will try to save a copy of old_self and see if that works...

JackCaoG commented 2 months ago

Ok I updated it to

        updates.append(
            f"""\
  auto saved_{outer_arg} = {outer_arg};
  at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
  at::functionalization::impl::replace_({outer_arg}, {inner_ret});
  at::functionalization::impl::commit_update({outer_arg});
  at::functionalization::impl::sync({outer_arg});
  at::functionalization::impl::propagate_xla_data(saved_{outer_arg}, {outer_arg});"""
        )

and got

  auto saved_self = self;
  at::functionalization::impl::propagate_xla_data(self, tmp_output);
  at::functionalization::impl::replace_(self, tmp_output);
  at::functionalization::impl::commit_update(self);
  at::functionalization::impl::sync(self);
  at::functionalization::impl::propagate_xla_data(saved_self, self);
    return self;

However, from the log I see that at::functionalization::impl::propagate_xla_data(saved_self, self); does not trigger the _propagate_xla_data in XLA land. I will try to get back to this issue later this week.

JackCaoG commented 2 months ago

https://github.com/pytorch/pytorch/pull/131076 should fix the issue