Closed jeffhataws closed 2 months ago
So I tried to test it on nightly, with
diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py
index b66780d5c..acd00edff 100644
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -11,6 +11,8 @@ import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn
+import torch_xla.debug.metrics as met
+
class TrainResNetBase():
@@ -29,7 +31,7 @@ class TrainResNetBase():
xr.world_size())
self.device = torch_xla.device()
- self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
+ self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device, batches_per_execution=4)
self.model = torchvision.models.resnet50().to(self.device)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
self.loss_fn = nn.CrossEntropyLoss()
@@ -51,6 +53,8 @@ class TrainResNetBase():
loss.backward()
self.run_optimizer()
tracker.add(self.batch_size)
+ print('alias count = ')
+ print(met.metric_data("InputOutputAliasCount"))
if step % 10 == 0:
xm.add_step_closure(
self._train_update, args=(step, loss, tracker, epoch))
we can do gradient accumulation every 4 steps for resnet50. When run with
PT_XLA_DEBUG_LEVEL=1 python examples/train_resnet_base.py
I see
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step in parallel loader at step end
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: e58a0059c5ae1ac0035a1d6fcefbc377
Compilation Analysis: Number of Graph Inputs: 272
Compilation Analysis: Number of Graph Outputs: 486
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py:1055)
Compilation Analysis: next (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:44)
Compilation Analysis: __next__ (/workspaces/dk3/pytorch/xla/torch_xla/distributed/parallel_loader.py:32)
Compilation Analysis: train_loop_fn (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:49)
Compilation Analysis: start_training (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:67)
Compilation Analysis: <module> (/workspaces/dk3/pytorch/xla/examples/train_resnet_base.py:75)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 0.167664 GB
Post Compilation Analysis: Graph output size: 0.199731 GB
Post Compilation Analysis: Aliased Input size: 0.095884 GB
Post Compilation Analysis: Intermediate tensor size: 8.436128 GB
Post Compilation Analysis: Compiled program size: 0.194783 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
epoch: 1, step: 0, loss: 6.9063191413879395, rate: 8.059381469430795
alias count =
(1, 267.0, ((1717453831.561344, 267.0),))
Note that 0.095884 GB of input is aliased, out of 0.167664 GB, and 267/272 of the inputs are being aliased. Now if I run it again with
XLA_DISABLE_FUNCTIONALIZATION=1 PT_XLA_DEBUG_LEVEL=1 python examples/train_resnet_base.py
I see exactly the same thing.
Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 0.167664 GB
Post Compilation Analysis: Graph output size: 0.192098 GB
Post Compilation Analysis: Aliased Input size: 0.095884 GB
Post Compilation Analysis: Intermediate tensor size: 8.455472 GB
Post Compilation Analysis: Compiled program size: 0.194823 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================
epoch: 1, step: 0, loss: 6.912147045135498, rate: 8.488426770395343
alias count =
(1, 267.0, ((1717454064.4393814, 267.0),))
This suggests that on nightly XLA_DISABLE_FUNCTIONALIZATION does not impact aliasing. It is hard to fix the 2.1 branch at this point. I would suggest just running with XLA_DISABLE_FUNCTIONALIZATION=1 until you upgrade to the 2.3 release.
OK, here is how I will debug this problem. Let's first understand when and how we do buffer aliasing in the 2.1 branch.
Building the input-output alias happens during compilation, so it is a compile-time decision. https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1234C19-L1236 It takes the tensors that represent outputs (tensors + indices are all XLATensors that will get a new output value after the current execution) and the lowering context as input.
We first find the tensor IDs of all output tensors https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1242-L1246
and then we loop through all of the input data (from the lowering context) https://github.com/pytorch/xla/blob/b7ea06efbf0bb1ae1f094c07db7387470cd787ee/torch_xla/csrc/xla_graph_executor.cpp#L1247-L1249
and see if any input data shares the same tensor ID as an output. If they share the same tensor ID, it means that some in-place operation happened and the XLATensor is no longer attached to this input data, so this input data can be aliased to the output with the same shape.
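The matching step described above can be sketched in plain Python. This is only a simplified model, not the actual C++ implementation: `ParamInfo` and `find_buffer_donors` are hypothetical names, the IDs are example values, and the shape check mentioned above is left out.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ParamInfo:
    """Hypothetical stand-in for the per-input info (tensor ID, read-only flag)."""
    tensor_id: int
    read_only: bool = False


def find_buffer_donors(output_alias_ids: List[int],
                       params: List[Optional[ParamInfo]]) -> List[int]:
    """Return the indices of graph inputs whose buffers could be donated."""
    # Step 1: record the ID of every output tensor.
    output_index_by_id = {tid: i for i, tid in enumerate(output_alias_ids)}
    donors = []
    # Step 2: walk the graph inputs and look for a matching ID.
    for param_index, info in enumerate(params):
        if info is None or info.read_only:
            continue
        # A match means the tensor gets a new buffer after execution,
        # so the old input buffer is no longer needed.
        if info.tensor_id in output_index_by_id:
            donors.append(param_index)
    return donors


params = [ParamInfo(1), ParamInfo(2), ParamInfo(3), ParamInfo(14)]
print(find_buffer_donors([14, 17, 19], params))      # [3]
print(find_buffer_donors([17, 19, 28, 29], params))  # []
```

In the second call no output carries ID 14, so nothing is donated even though an input with ID 14 exists.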
Note that we make a couple of assumptions here
In order to debug the issue above, I think what you want to do is
and see why they don't match
Just want to document some findings from the original MLP test on CPU only, printing met.metric_data("InputOutputAliasCount"):
pt2.1 with functionalization on CPU:
(4, 42.0, ((1719204665.609187, 16.0), (1719204667.7811365, 2.0), (1719204668.486205, 12.0), (1719204668.6342065, 12.0)))
pt2.1 without functionalization on CPU:
(4, 58.0, ((1719204732.9488149, 16.0), (1719204735.0688317, 6.0), (1719204735.781659, 12.0), (1719204735.9386797, 24.0)))
pt2.3 with functionalization on CPU:
(4, 42.0, ((1719204826.4645836, 16.0), (1719204829.2127182, 2.0), (1719204829.8989058, 12.0), (1719204830.0465565, 12.0)))
pt2.3 without functionalization on CPU:
(4, 58.0, ((1719204809.890676, 16.0), (1719204812.663207, 6.0), (1719204813.3367858, 12.0), (1719204813.4833145, 24.0)))
pt2.4 with functionalization on CPU:
(4, 54.0, ((1719207116.7393954, 16.0), (1719207119.7452862, 2.0), (1719207120.3934603, 12.0), (1719207120.5731502, 24.0)))
pt2.4 without functionalization on CPU:
(4, 58.0, ((1719207140.930515, 16.0), (1719207143.827449, 6.0), (1719207144.4554996, 12.0), (1719207144.6054273, 24.0)))
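To make the comparison easier to read, here is a small sketch that extracts the per-compilation alias counts from two of the tuples above (pt2.1 with and without functionalization). It assumes the tuple layout is (number of samples, accumulated total, samples as (timestamp, value) pairs); `alias_counts` is a hypothetical helper, not part of torch_xla.

```python
def alias_counts(metric):
    """Return the per-sample alias counts from a metric_data-style tuple."""
    _num_samples, _total, samples = metric
    return [int(value) for _timestamp, value in samples]


# pt2.1 with functionalization on CPU (copied from above)
with_func = (4, 42.0, ((1719204665.609187, 16.0), (1719204667.7811365, 2.0),
                       (1719204668.486205, 12.0), (1719204668.6342065, 12.0)))
# pt2.1 without functionalization on CPU (copied from above)
without_func = (4, 58.0, ((1719204732.9488149, 16.0), (1719204735.0688317, 6.0),
                          (1719204735.781659, 12.0), (1719204735.9386797, 24.0)))

# How many aliases each compilation loses when functionalization is enabled.
lost = [b - a for a, b in zip(alias_counts(with_func), alias_counts(without_func))]
print(lost)       # [0, 4, 0, 12]
print(sum(lost))  # 16
```

The second and fourth compilations account for the whole gap between the totals 58 and 42.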
Minimal reproduction with only 1 linear layer and only gradient accumulation:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import numpy as np
import copy
dev = xm.xla_device()
train_x_sample = torch.rand((1, 28 * 28))
train_label_sample = torch.tensor([5])
class MLP(nn.Module):
def __init__(self, input_size = 28 * 28, output_size = 10):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, output_size, bias=False)
def forward(self, x):
x = self.fc1(x)
return F.log_softmax(x, dim=1)
c_model = MLP().to('cpu')
t_model = copy.deepcopy(c_model).to(dev)
loss_fn = nn.NLLLoss()
def try_grad_accum(model, device, train_x, train_label, accum_steps):
train_x = train_x.to(device)
train_label = train_label.to(device)
model.zero_grad()
for i in range(accum_steps):
output = model(train_x)
t_loss = loss_fn(output, train_label)
t_loss.backward()
xm.mark_step()
return [p.grad.to('cpu').numpy() for p in model.parameters()]
def test_grad_accum():
t_model.train()
c_model.train()
accum_steps = 4
c_grads_5 = try_grad_accum(c_model, 'cpu', train_x_sample, train_label_sample, accum_steps)
t_grads_5 = try_grad_accum(t_model, dev, train_x_sample, train_label_sample, accum_steps)
np.testing.assert_allclose(t_grads_5, c_grads_5, rtol=3e-2, atol=1e-3)
print("XLA_DISABLE_FUNCTIONALIZATION:", os.environ["XLA_DISABLE_FUNCTIONALIZATION"],
" InputOutputAliasCount:", met.metric_data("InputOutputAliasCount"))
if __name__ == '__main__':
test_grad_accum()
Disable functionalization (see aliasing of gradient buffer):
XLA_FLAGS=--xla_dump_to="./dump_nofunc" XLA_DISABLE_FUNCTIONALIZATION=1 python gradacc_test.py
XLA_DISABLE_FUNCTIONALIZATION: 1 InputOutputAliasCount: (2, 1.0, ((1719295462.695567, 0.0), (1719295462.758401, 1.0)))
dump_nofunc/module_0001.SyncTensorsGraph.115.cpu_after_optimizations.txt also shows aliasing:
HloModule SyncTensorsGraph.115, is_scheduled=true, input_output_alias={ {0}: (3, {}, may-alias) }, entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, f32[10,784]{1,0})->(f32[10,784]{1,0}, f32[1,10]{1,0}, f32[])}
Enable functionalization (see no aliasing):
XLA_FLAGS=--xla_dump_to="./dump_func" XLA_DISABLE_FUNCTIONALIZATION=0 python gradacc_test.py
XLA_DISABLE_FUNCTIONALIZATION: 0 InputOutputAliasCount: (2, 0.0, ((1719295465.6997023, 0.0), (1719295465.7658002, 0.0)))
dump_func/module_0001.SyncTensorsGraph.117.cpu_after_optimizations.txt shows no aliasing:
HloModule SyncTensorsGraph.117, is_scheduled=true, entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, f32[10,784]{1,0})->(f32[1,10]{1,0}, f32[], f32[784,10]{1,0}, f32[10,784]{1,0})}
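For larger dumps it may be easier to check the HloModule header lines programmatically. The sketch below parses the `input_output_alias` attribute out of a header string; `parse_input_output_alias` is a hypothetical helper, and the regex only assumes the entry format seen in the header above.

```python
import re

# Matches one alias entry such as "{0}: (3, {}, may-alias)".
_ALIAS_ENTRY = re.compile(r"\{([\d,\s]*)\}:\s*\((\d+),")


def parse_input_output_alias(hlo_header):
    """Return (output_index, parameter_number) pairs from an HloModule header."""
    start = hlo_header.find("input_output_alias=")
    if start == -1:
        return []
    # Stop before entry_computation_layout, which contains layouts like {1,0}.
    end = hlo_header.find("entry_computation_layout", start)
    section = hlo_header[start:end] if end != -1 else hlo_header[start:]
    return [(out.strip(), int(param)) for out, param in _ALIAS_ENTRY.findall(section)]


no_func = ("HloModule SyncTensorsGraph.115, is_scheduled=true, "
           "input_output_alias={ {0}: (3, {}, may-alias) }, "
           "entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, "
           "f32[10,784]{1,0})->(f32[10,784]{1,0}, f32[1,10]{1,0}, f32[])}")
func = ("HloModule SyncTensorsGraph.117, is_scheduled=true, "
        "entry_computation_layout={(f32[10,784]{1,0}, f32[1,784]{1,0}, s64[1]{0}, "
        "f32[10,784]{1,0})->(f32[1,10]{1,0}, f32[], f32[784,10]{1,0}, f32[10,784]{1,0})}")

print(parse_input_output_alias(no_func))  # [('0', 3)]
print(parse_input_output_alias(func))     # []
```

Output 0 aliasing parameter 3 matches the gradient buffer of shape f32[10,784] in the no-functionalization dump.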
Using TOT, I modified torch_xla/csrc/xla_graph_executor.cpp to dump some info:
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 74c3270a9..d31924fb4 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1264,6 +1264,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
size_t tensor_index = indices[i];
int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
output_tensor_id_map[tensor_id] = i;
+ TF_VLOG(3) << "SetBufferDonors indices_index: " << i << " tensor_index: " << tensor_index << " tensor_id: " << tensor_id;
}
const auto& parameters_data = lowering_ctx->GetParametersData();
std::vector<ssize_t> alias_map(indices.size(), -1);
@@ -1271,8 +1272,10 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
auto* data_info =
static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
parameters_data[i]->info());
+ TF_VLOG(3) << "SetBufferDonors param index: " << i << " data_info: " << data_info;
if (data_info != nullptr && !data_info->read_only) {
auto it = output_tensor_id_map.find(data_info->tensor_id);
+ TF_VLOG(3) << "SetBufferDonors data_info->tensor_id: " << data_info->tensor_id;
// Parameter buffer's TensorId in output_tensor_id_map means
// this buffer is not needed after execution since XLATensor will get a
// new buffer.
@@ -1284,6 +1287,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
}
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
+ TF_VLOG(3) << "SetBufferDonors buffer_donor_indexs: " << buffer_donor_indexs;
return buffer_donor_indexs;
}
Disable functionalization (timestamps removed to improve readability):
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 6
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 8
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5583fe936890
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5583fe968640
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5583fe902550
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash a7fe2231121157aab2b76c36f9085caf on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash a7fe2231121157aab2b76c36f9085caf on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 17
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 19
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5583fe936890
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5583fe968640
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5583fe903160
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 3 data_info: 0x7f7910019030
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs: 3
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 8c5d2f3da9d51ab21d4f711702e98d0c on device CPU:0 done!
XLA_DISABLE_FUNCTIONALIZATION: 1 InputOutputAliasCount: (2, 1.0, ((1719438778.8729675, 0.0), (1719438778.9325185, 1.0)))
With functionalization:
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 6
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 8
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 13
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5556f7ef6d60
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5556f7f5ceb0
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5556f7ef5910
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 86eb0fe9e6afb2f4cec87b9250f18010 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 86eb0fe9e6afb2f4cec87b9250f18010 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id: 17
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id: 19
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id: 28
torch_xla/csrc/xla_graph_executor.cpp:1269] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id: 29
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 0 data_info: 0x5556f7ef6d60
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 1
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 1 data_info: 0x5556f7f4a180
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 2
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 2 data_info: 0x5556f7f577e0
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 3
torch_xla/csrc/xla_graph_executor.cpp:1277] SetBufferDonors param index: 3 data_info: 0x7fb9c4018300
torch_xla/csrc/xla_graph_executor.cpp:1280] SetBufferDonors data_info->tensor_id: 14
torch_xla/csrc/xla_graph_executor.cpp:1293] SetBufferDonors buffer_donor_indexs:
torch_xla/csrc/xla_graph_executor.cpp:1421] Compiling IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
torch_xla/csrc/xla_graph_executor.cpp:1132] Executing IR graph hash 578c2535b5823addcb078d3b876947e1 on device CPU:0 done!
XLA_DISABLE_FUNCTIONALIZATION: 0 InputOutputAliasCount: (2, 0.0, ((1719438794.6676075, 0.0), (1719438794.7328753, 0.0)))
In the functionalization graph, there's one additional output of shape (784, 10) that's not there in the no-functionalization case.
OK, I was able to repro and I have a rough idea of what's going on. This bug seems to only happen with gradient accumulation.
We first have model.fc1.weight.grad after the first step, and in the second step PyTorch will use += to accumulate the grad. The tensor ID for model.fc1.weight.grad was 14 before the second step, and during the backward, PyTorch/XLA received a call to XLANativeFunctions::_propagate_xla_data which tells us that some in-place operation happened between tensor ID 14 and tensor ID 26. PyTorch/XLA learned to link these two tensors together. However, after the backward is finished, the tensor ID of the new model.fc1.weight.grad becomes 29. In this case PyTorch/XLA can't tell that tensor ID 14 and tensor ID 29 are the same tensor, hence it will not be aliased.
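A toy model of the ID bookkeeping may help here. It is a deliberately simplified sketch: `ToyTensor` and `propagate` are hypothetical stand-ins, and the real _propagate_xla_data has more branches than the single assignment shown.

```python
from dataclasses import dataclass


@dataclass
class ToyTensor:
    """Hypothetical stand-in for an XLATensor: a unique ID plus an alias ID."""
    unique_id: int
    alias_id: int


def make(unique_id):
    # A fresh tensor initially aliases only itself.
    return ToyTensor(unique_id, unique_id)


def propagate(inp, out):
    # Simplified: link the output back to the input it logically replaces.
    out.alias_id = inp.unique_id


graph_input_ids = {14}        # grad buffer from the first step

grad = make(14)
added = make(26)              # result of the add_ in gradient accumulation
propagate(grad, added)        # the call PyTorch/XLA does receive

synced = make(29)             # new grad after the backward; no propagate call

print(added.alias_id in graph_input_ids)   # True
print(synced.alias_id in graph_input_ids)  # False
```

Only the tensor that ends up as the graph output matters for aliasing, and in this model that is the one with ID 29, which never got linked back to 14.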
@bdhirsh can you point me to where the gradient accumulation happens in PyTorch? It happens in C++, so it is a bit hard for me to trace the exact location without doing a debug build of both PyTorch and PyTorch/XLA.
My theory is that functionalization did something weird for the gradient accumulation, which confuses PyTorch/XLA so we can't really tell this is an in-place call.
@jeffhataws what I did was add
print(torch_xla._XLAC._get_xla_tensor_debug_info(model.fc1.weight.grad))
after the mark_step and add the following debug prints
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 5e233d051..e222231f3 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2707,6 +2707,12 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
// 1) Aid XLA's InputOutputAlias.
auto input_tensor = bridge::GetXlaTensor(input);
auto output_tensor = bridge::GetXlaTensor(output);
+ std::cout << "output tensor's alias ID: " << output_tensor->data()->alias_id
+ << "\n";
+ std::cout << "input tensor's tensor ID: " << input_tensor->GetUniqueId()
+ << "\n";
+ std::cout << "input tensor's alias ID: " << input_tensor->data()->alias_id
+ << "\n";
if (input_tensor->CurrentDataHandle() != nullptr ||
(input_tensor->CurrentIrValue().node != nullptr &&
torch_xla::DeviceData::Cast(
@@ -2722,7 +2728,7 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
xm.mark_step()
// x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2
// for this graph
- x *= 1 of 1
+ x *= 1
*/
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
} else {
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0cbb21c44..bd8a9af06 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1275,6 +1275,8 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
for (size_t i = 0; i < indices.size(); ++i) {
size_t tensor_index = indices[i];
int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
+ std::cout << "adding alias_id " << tensor_id << " to the map"
+ << "\n";
output_tensor_id_map[tensor_id] = i;
}
const auto& parameters_data = lowering_ctx->GetParametersData();
@@ -1283,6 +1285,7 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
auto* data_info =
static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
parameters_data[i]->info());
+ std::cout << "parameter tensor id = " << data_info->tensor_id << "\n";
if (data_info != nullptr && !data_info->read_only) {
auto it = output_tensor_id_map.find(data_info->tensor_id);
// Parameter buffer's TensorId in output_tensor_id_map means
Hmm @JackCaoG the grad accumulation impl is here: https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/functions/accumulate_grad.h#L116
One additional datapoint is that when I increase the gradient accumulation count, I see that the input tensor ID keeps changing for each gradient accumulation step when functionalization is on. Is that expected?
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 17 to the map to indices_index 0
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id(alias_id): 17
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 19 to the map to indices_index 1
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id(alias_id): 19
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 28 to the map to indices_index 2
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id(alias_id): 28
xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 29 to the map to indices_index 3
xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id(alias_id): 29
xla_graph_executor.cpp:1276] SetBufferDonors param index: 0 data_info->tensor_id: 1
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 1
xla_graph_executor.cpp:1276] SetBufferDonors param index: 1 data_info->tensor_id: 2
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 2
xla_graph_executor.cpp:1276] SetBufferDonors param index: 2 data_info->tensor_id: 3
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 3
xla_graph_executor.cpp:1276] SetBufferDonors param index: 3 data_info->tensor_id: 14
xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 14
xla_graph_executor.cpp:1291] SetBufferDonors buffer_donor_indexs:
xla_graph_executor.cpp:1412] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1419] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 41 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 29
aten_xla_type.cpp:2716] input tensor's alias ID: 29 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 56 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 44
aten_xla_type.cpp:2716] input tensor's alias ID: 44 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 71 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 59
aten_xla_type.cpp:2716] input tensor's alias ID: 59 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
aten_xla_type.cpp:2714] output tensor's alias ID: 86 shape: f32[10,784]
aten_xla_type.cpp:2715] input tensor's tensor ID: 74
aten_xla_type.cpp:2716] input tensor's alias ID: 74 shape: f32[10,784]
aten_xla_type.cpp:2735] assigning output tensor's alias ID to input tensor's (unique) ID
The tensor ID changing is expected; with functionalization there are no real in-place ops. We can still alias the underlying buffer as long as we know they are supposedly the same tensor (through _propagate_xla_data).
Makes sense. I think for the second graph, the input tensor_id 14 should be aliased to tensor_id 3, instead of aliased to itself (as indicated by the map in SetBufferDonors: "adding alias_id 14 to the map to indices_index 3"). This way, tensor_id 3 would keep propagating.
I found where tensor IDs 27-29 came from
bool FunctionalStorageImpl::apply_updates() {
  // N.B:none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data: updates_) {
    base_ = apply_update(update_data, base_);  // <- debugger stopped here
  }
  updates_.clear();
  return any_updates;
}
This will trigger t_copy in pytorch/xla, and it comes from
  at::functionalization::impl::replace_(self, tmp_output);
  at::functionalization::impl::commit_update(self);
  at::functionalization::impl::sync(self);  // <- debugger stopped here
  return self;
which is part of
at::Tensor & add__Tensor(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
in pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp.
Here is what happened. The tensor with ID 26 was created by the IR add; this is due to the add_ in the grad accumulation. _propagate_xla_data was correctly called and PyTorch/XLA knows that 26 and 14 are the same tensor. However, after the add is finished, at::functionalization::impl::sync triggered a couple of permute ops and replaced base_ in place with new tensors. There is no way for PyTorch/XLA to know that the new permute IR is the same tensor as the add tensor.
@bdhirsh can we also call _propagate_xla_data for base_ before and after apply_update?
Thanks @JackCaoG for the detailed root cause. @bdhirsh do you think this can be fixed in torch 2.4?
I will give this a try today or tomorrow.
Hi @JackCaoG , I tried this
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 8d2c567c347..2392fe60a1d 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -589,7 +589,9 @@ def wrap_propagate_mutations_and_return(
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});
- at::functionalization::impl::sync({outer_arg});"""
+ at::functionalization::impl::sync({outer_arg});
+ at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
+ """
)
But I am still seeing empty buffer_donor_indexs:
2024-07-04 23:04:51.386411: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 17 to the map to indices_index 0
2024-07-04 23:04:51.386431: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 0 tensor_index: 3 tensor_id(alias_id): 17
2024-07-04 23:04:51.386439: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 19 to the map to indices_index 1
2024-07-04 23:04:51.386447: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 1 tensor_index: 4 tensor_id(alias_id): 19
2024-07-04 23:04:51.386456: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 28 to the map to indices_index 2
2024-07-04 23:04:51.386464: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 2 tensor_index: 5 tensor_id(alias_id): 28
2024-07-04 23:04:51.386473: I torch_xla/csrc/xla_graph_executor.cpp:1267] SetBufferDonors: adding alias_id 29 to the map to indices_index 3
2024-07-04 23:04:51.386481: I torch_xla/csrc/xla_graph_executor.cpp:1268] SetBufferDonors indices_index: 3 tensor_index: 6 tensor_id(alias_id): 29
2024-07-04 23:04:51.386491: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 0 data_info->tensor_id: 1
2024-07-04 23:04:51.386501: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 1
2024-07-04 23:04:51.386510: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 1 data_info->tensor_id: 2
2024-07-04 23:04:51.386519: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 2
2024-07-04 23:04:51.386528: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 2 data_info->tensor_id: 3
2024-07-04 23:04:51.386537: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 3
2024-07-04 23:04:51.386546: I torch_xla/csrc/xla_graph_executor.cpp:1276] SetBufferDonors param index: 3 data_info->tensor_id: 14
2024-07-04 23:04:51.386555: I torch_xla/csrc/xla_graph_executor.cpp:1279] SetBufferDonors data_info->tensor_id: 14
2024-07-04 23:04:51.386564: I torch_xla/csrc/xla_graph_executor.cpp:1291] SetBufferDonors buffer_donor_indexs:
2024-07-04 23:04:51.386656: I torch_xla/csrc/xla_graph_executor.cpp:1412] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
2024-07-04 23:04:51.436769: I torch_xla/csrc/xla_graph_executor.cpp:1419] Compiling IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
2024-07-04 23:04:51.436931: I torch_xla/csrc/xla_graph_executor.cpp:1119] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 ...
2024-07-04 23:04:51.436995: I torch_xla/csrc/xla_graph_executor.cpp:1130] Executing IR graph hash 417805772c05e5e266b5983e46792f74 on device CPU:0 done!
2024-07-04 23:04:51.437385: I torch_xla/csrc/aten_xla_type.cpp:2714] output tensor's tensor ID: 41
2024-07-04 23:04:51.437407: I torch_xla/csrc/aten_xla_type.cpp:2715] output tensor's alias ID: 41 shape: f32[10,784]
2024-07-04 23:04:51.437414: I torch_xla/csrc/aten_xla_type.cpp:2716] input tensor's tensor ID: 29
2024-07-04 23:04:51.437420: I torch_xla/csrc/aten_xla_type.cpp:2717] input tensor's alias ID: 29 shape: f32[10,784]
2024-07-04 23:04:51.437426: I torch_xla/csrc/aten_xla_type.cpp:2755] assigning output tensor's alias ID to input tensor's (unique) ID
2024-07-04 23:04:51.437460: I torch_xla/csrc/aten_xla_type.cpp:2714] output tensor's tensor ID: 41
2024-07-04 23:04:51.437467: I torch_xla/csrc/aten_xla_type.cpp:2715] output tensor's alias ID: 29 shape: f32[10,784]
2024-07-04 23:04:51.437473: I torch_xla/csrc/aten_xla_type.cpp:2716] input tensor's tensor ID: 44
2024-07-04 23:04:51.437479: I torch_xla/csrc/aten_xla_type.cpp:2717] input tensor's alias ID: 44 shape: f32[10,784]
2024-07-04 23:04:51.437485: I torch_xla/csrc/aten_xla_type.cpp:2769] assigning output tensor's alias ID to input tensor's alias ID
@bdhirsh I gave this a try and did
- for (auto& update_data: updates_) {
- base_ = apply_update(update_data, base_);
+ for (auto& update_data : updates_) {
+ auto new_base_ = apply_update(update_data, base_);
+ at::functionalization::impl::propagate_xla_data(base_, new_base_);
+ base_ = new_base_;
}
The propagate_xla_data call fails with
RuntimeError: isFunctionalTensor(functional_tensor) INTERNAL ASSERT FAILED at "/workspaces/dk3/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":710, please report a bug to PyTorch.
I feel like I'm close, lol. Could you shed some light on the right way of doing this?
hmm @JackCaoG , can you try out this patch?
diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp
index 3275c8f447f..d788aa29261 100644
--- a/aten/src/ATen/FunctionalStorageImpl.cpp
+++ b/aten/src/ATen/FunctionalStorageImpl.cpp
@@ -6,6 +6,13 @@
#include <c10/util/Exception.h>
#include <vector>
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/_propagate_xla_data.h>
+#include <ATen/ops/_to_copy.h>
+#endif
+
namespace at::functionalization {
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
@@ -136,7 +143,11 @@ bool FunctionalStorageImpl::apply_updates() {
at::AutoDispatchSkipFunctionalize guard;
bool any_updates = !updates_.empty();
for (auto& update_data: updates_) {
+ at::Tensor base_old_ = base_;
base_ = apply_update(update_data, base_);
+ if (base_.key_set().has(c10::DispatchKey::XLA)) {
+ at::_propagate_xla_data(base_old_, base_);
+ }
}
updates_.clear();
return any_updates;
For the patch you posted above: you're calling at::functionalization::impl::propagate_xla_data(a, b), which expects the first tensor to be a functional wrapper and the second to be a normal tensor. You can just bypass that wrapper and call the op directly, with your two inner tensors.
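To make the wrapper-versus-inner distinction concrete, here is a toy Python model. It is not PyTorch code: InnerTensor, FunctionalWrapper, raw_propagate and checked_propagate are invented names, and the assert only imitates the isFunctionalTensor check from the error above. The sketch assumes the wrapper-level helper unwraps its first argument and then calls the raw op, which is how the comment above describes it.

```python
from dataclasses import dataclass, field
from itertools import count

_ids = count(1)


@dataclass
class InnerTensor:
    """Hypothetical stand-in for a backend (e.g. XLA) tensor."""
    data: list
    uid: int = field(default_factory=lambda: next(_ids))
    alias_id: int = 0

    def __post_init__(self):
        # A fresh tensor aliases only itself.
        self.alias_id = self.uid


@dataclass
class FunctionalWrapper:
    """Hypothetical stand-in for a functional tensor wrapper."""
    inner: InnerTensor


def raw_propagate(src: InnerTensor, dst: InnerTensor) -> None:
    """Toy version of the raw op: dst inherits src's alias id."""
    dst.alias_id = src.alias_id


def checked_propagate(wrapper, inner: InnerTensor) -> None:
    """Toy version of the wrapper-level helper: check, unwrap, call the raw op."""
    assert isinstance(wrapper, FunctionalWrapper), "isFunctionalTensor(functional_tensor)"
    raw_propagate(wrapper.inner, inner)


old_base = InnerTensor([1.0, 2.0])
new_base = InnerTensor([3.0, 4.0])

# Passing two inner tensors to the wrapper-level helper trips the assert.
try:
    checked_propagate(old_base, new_base)
    helper_failed = False
except AssertionError:
    helper_failed = True

# Calling the raw op directly with the two inner tensors works.
raw_propagate(old_base, new_base)
```

Inside apply_updates() both base_ and the result of apply_update are already inner tensors, so in this model only the raw call is applicable there.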
Thanks @bdhirsh, I think this partially fixed the problem. I found there is another place I need to fix:
void FunctionalTensorWrapper::regenerate_from_base() {
  at::AutoDispatchSkipFunctionalize guard;
  auto storage_impl = functional_storage_impl();
  auto t = storage_impl->base();

  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  t = apply_view_metas(t);  // <-- line marked in the debugger
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));

  replace_(t, /*from_lazy_regenerate=*/true);
  generation_ = storage_impl->generation();
}
this is called by
void FunctionalTensorWrapper::sync_() {
  if (is_up_to_date()) {
    return;
  }
  apply_updates();
  regenerate_from_base();  // <-- line marked in the debugger
}
Hopefully this will fix the whole problem.
ehh this is giving me too much headache lol
We first have tensor ID 13 with shape [784, 10]; I believe it comes from the backward of the matmul.
create with at ir value, id 13
ir value [] aten::mm, xla_shape=f32[784,10]{1,0}, dynamic_dims: ()
Then we create the real grad with tensor ID 14, which is the transpose of tensor 13.
create with at ir value, id 14
ir value [] aten::permute, xla_shape=f32[10,784]{0,1}, dynamic_dims: (), dims=(1, 0)
I checked that it comes from
variable_list TBackward0::apply(variable_list&& grads) {

  IndexRangeGenerator gen;
  auto self_ix = gen.range(1);
  variable_list grad_inputs(gen.size());
  const auto& grad = grads[0];
  bool any_grad_defined = any_variable_defined(grads);
  if (task_should_compute_output({ self_ix })) {
    auto grad_result = any_grad_defined ? (grad.t()) : Tensor();  // <-- line marked in the debugger
    copy_range(grad_inputs, self_ix, grad_result);
  }
  return grad_inputs;
}
If we print the debug info of model.fc1.weight.grad, it is
XLATensor {
TensorID: 14
Device: TPU:0
XLA Shape: f32[10,784]
ShardingSpec: None
IR: None
XLAData:
Data Device: TPU:0
Data Shape: f32[10,784]
Data Handle: None
Tensor on host: None
}
Now in the second step things become complicated. First we do the grad_accum, and the += creates the tensor with ID 26:
create with at ir value, id 26
ir value [] aten::add, xla_shape=f32[10,784]{1,0}, dynamic_dims: ()
This tensor is correctly linked to tensor ID 14. Up to here everything is as expected.
However, in FunctionalStorageImpl::apply_updates() I saw
create with at ir value, id 27
ir value [] aten::permute, xla_shape=f32[10,784]{1,0}, dynamic_dims: (), dims=(1, 0)
create with at ir value, id 28
ir value [] aten::permute, xla_shape=f32[784,10]{0,1}, dynamic_dims: (), dims=(1, 0)
and then we tried to _propagate_xla_data, but weirdly enough it links tensor ID 28 [784, 10] with tensor ID 13 [784, 10].
It then creates tensor ID 29
ir value [] aten::permute, xla_shape=f32[10,784]{1,0}, dynamic_dims: (), dims=(1, 0)
and links it to tensor ID 13.
The problem is that, from the XLA perspective, the input to the graph is tensor ID 14 (the transposed version) instead of 13 (the non-transposed one). I think the issue is that from the functionalization perspective
ok I think I know what happened.
t14= t13.t()
t14 += temp
t14's base is t13. After the += was applied on t14, sync_ will try to replay the operation from the base (t13), and that's where all of the xla_data starts to get ruined. I think what @jeffhataws did in https://github.com/pytorch/xla/issues/7174#issuecomment-2223677581 is an easier solution to this problem, but I think we need to do
at::functionalization::impl::propagate_xla_data({inner_ret}, {outer_arg});
(reverse the inner and outer). Let me try this approach.
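The replay-from-base behaviour described above can be illustrated with a toy Python model. This is not PyTorch or torch_xla code: Tensor, add and the buffer attribute are invented for illustration, the shapes are shrunk to [3, 2], and the ID counter starts at 13 only to echo the IDs in the logs. The sketch assumes, as the comment above suggests, that an in-place update on a view is written back to the base and the view is then rebuilt from that base as a new tensor.

```python
from itertools import count

_ids = count(13)  # start at 13 only to echo the IDs in the thread


class Tensor:
    """Hypothetical backend tensor: values, an ID, and an optional device-buffer tag."""

    def __init__(self, rows):
        self.rows = rows
        self.uid = next(_ids)
        self.buffer = None  # name of the device buffer this tensor owns, if any

    def t(self):
        # Transpose always produces a brand-new tensor object with a new ID.
        return Tensor([list(col) for col in zip(*self.rows)])


def add(a, b):
    """Out-of-place add, the functional form of +=."""
    return Tensor([[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a.rows, b.rows)])


# Step 1: the gradient is a transposed view of its base.
t13 = Tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape [3, 2]
t14 = t13.t()                                        # shape [2, 3]
t14.buffer = "grad_buffer"  # t14 is the tensor that owns the device data

# Step 2: accumulate into the view, expressed functionally.
temp = Tensor([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]])
updated_view = add(t14, temp)

# Write the update back into the base (inverse of the view), then
# regenerate the view by replaying t() from the new base.
new_base = updated_view.t()
regenerated = new_base.t()

print(regenerated.rows == updated_view.rows)  # values survive the round trip
print(regenerated.uid, regenerated.buffer)    # but the ID is new and the tag is gone
```

In this model the regenerated view holds the right values but is no longer the object that owned grad_buffer, so nothing ties the step's output back to the input buffer unless the tag is propagated explicitly.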
ok it failed with
RuntimeError: isFunctionalTensor(functional_tensor) INTERNAL ASSERT FAILED at "/workspaces/dk3/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":719, please report a bug to PyTorch.
but I think it is the right thing to do.
@bdhirsh I think the issue is in
at::functionalization::impl::propagate_xla_data(self, tmp_output);
at::functionalization::impl::replace_(self, tmp_output);
at::functionalization::impl::commit_update(self);
at::functionalization::impl::sync(self);
self will be replaced in sync, so we want to re-propagate the xla_data from tmp_output to new_self (or from old_self to new_self). I will try saving a copy of old_self and see if that works...
Ok I updated it to
updates.append(
f"""\
auto saved_{outer_arg} = {outer_arg};
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});
at::functionalization::impl::sync({outer_arg});
at::functionalization::impl::propagate_xla_data(saved_{outer_arg}, {outer_arg});"""
)
and get
auto saved_self = self;
at::functionalization::impl::propagate_xla_data(self, tmp_output);
at::functionalization::impl::replace_(self, tmp_output);
at::functionalization::impl::commit_update(self);
at::functionalization::impl::sync(self);
at::functionalization::impl::propagate_xla_data(saved_self, self);
return self;
However, from the log I saw that at::functionalization::impl::propagate_xla_data(saved_self, self); does not trigger _propagate_xla_data in XLA land. I will try to get back to this issue later this week.
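The template-to-output step in this comment can be reproduced in isolation. The snippet below is a minimal sketch that only expands the f-string quoted above; the values of outer_arg and inner_ret are assumptions read off the generated code, and the surrounding codegen that appends "return self;" is not modelled.

```python
outer_arg = "self"        # assumed from the generated code in the comment
inner_ret = "tmp_output"  # assumed from the generated code in the comment

updates = []
updates.append(
    f"""\
auto saved_{outer_arg} = {outer_arg};
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
at::functionalization::impl::commit_update({outer_arg});
at::functionalization::impl::sync({outer_arg});
at::functionalization::impl::propagate_xla_data(saved_{outer_arg}, {outer_arg});"""
)

# Join the collected snippets the way a code generator might before emitting them.
generated = "\n".join(updates)
print(generated)
```

Printing the result gives the six statements shown in the comment, which makes it easy to check a template change before rebuilding PyTorch.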
https://github.com/pytorch/pytorch/pull/131076 should fix the issue
🐛 Bug
When functionalization is on (XLA_DISABLE_FUNCTIONALIZATION=0), I see that there are fewer aliased tensors. Jack has a patch to increase the number of aliased tensors: https://github.com/pytorch/xla/commit/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd . However, even though this change helped increase the number of aliased tensors, aliasing still seems to be missing for gradients when gradient accumulation is used.
Using test_train_mp_mnist.py, make the modifications below. I added a mark_step to isolate the gradient accumulation loops.
I only see 2 aliases even though we expect all the gradient tensors to be aliased:
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Expect gradients to be aliased
Environment
Additional context