pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Is there any way to directly execute the cached computational graph #7682

Open mars1248 opened 1 month ago

mars1248 commented 1 month ago

❓ Questions and Help

My application code is complex but not computationally expensive, and the graph is the same from step to step, so I tried to cache it with XLA_PERSISTENT_CACHE_PATH. However, it still takes a long time to run the logic (without performing any computation). Is there any way to execute the cached graph directly? I also tried dynamo, but encountered many errors, such as incompatibility with autocast and so on.
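
(For context, the persistent cache referenced above is typically enabled as in the minimal sketch below, where the path is only an example; it caches compiled executables, so the per-step lazy-tensor tracing still runs, which is the cost described here.)

import os

# set before torch_xla initializes, to be safe; the path is only an example
os.environ['XLA_PERSISTENT_CACHE_PATH'] = '/tmp/xla_compile_cache'

import torch_xla  # compiled executables are now persisted to and reloaded from disk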

JackCaoG commented 1 month ago

Technically there is; you can look at our dynamo implementation, where we:

  1. execute the tracing https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/torch_xla/core/dynamo_bridge.py#L337
  2. compute the hash + warm up the cache (compilation) https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/torch_xla/core/dynamo_bridge.py#L395-L399
  3. execute the hash with input https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/torch_xla/core/dynamo_bridge.py#L497

Dynamo is supposed to do what you expect; it handles the input ordering, output ordering, functionalization of the graph, etc. If you use these APIs directly you need to be very careful.
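
For reference, a condensed sketch of driving those three steps by hand with the same private torch_xla._XLAC calls that appear in the snippets later in this thread (forward pass only; the toy model, shapes, and variable names are placeholders, and these internal APIs are tied to the pinned commit above and may change between releases):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.core.dynamo_bridge import GraphInputMatcher

xla_dev = xm.xla_device()
model = torch.nn.Linear(5, 3).to(xla_dev)
inputs = (torch.rand(10, 5, device=xla_dev),)
xm.mark_step()

# remember which tensor id belongs to which caller-supplied input
tensor_id_to_arg_idx = {
    torch_xla._XLAC._xla_get_tensor_id(t): i for i, t in enumerate(inputs)
}

# 1. trace: run the model once so the lazy graph is built
output_list = [model(*inputs).sum()]

# 2. compute the hash and warm up the compilation cache
graph_hash = torch_xla._XLAC._get_graph_hash(output_list)
torch_xla._XLAC._xla_warm_up_cache(output_list, [])

# map the graph's device-data nodes back to the caller's input tensors
(
    graph_input_tensor_ids,
    graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(output_list)
matcher = GraphInputMatcher(tensor_id_to_arg_idx, graph_input_tensor_ids,
                            graph_input_xla_values,
                            set(tensor_id_to_arg_idx.keys()))

# 3. execute the cached graph directly with the re-matched inputs
res = torch_xla._XLAC._run_cached_graph(graph_hash, matcher(inputs))
print(res[0])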

mars1248 commented 1 month ago

Thank you very much for your answer. Following your tips and this unit test, https://github.com/pytorch/xla/blob/08e63e32af9eee71e8cd13d672f3200ee3356ab4/test/dynamo/test_graph_input_matcher.py, I have successfully run the forward pass of the model, but I do not know how to add the backward pass and the optimizer state update.

JackCaoG commented 1 month ago

Technically you can do:

loss = fwd(input)
loss.backward()
optimizer.step()
# pass the loss plus every parameter gradient as the graph roots
graph_hash = torch_xla._XLAC._get_graph_hash(
    [loss] + [p.grad for p in model.parameters() if p.grad is not None])

From the XLA perspective there is no fwd and bwd; you just need to pass all of the outputs (in this case the gradients) and it will use those as roots to construct the whole graph.

mars1248 commented 1 month ago

Thank you very much for your reply. I went through the whole process, but I found that the parameters were not updated, so the loss (in my case res[0]) stays the same. I constructed a minimal test that reproduces this problem. @JackCaoG


import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree

class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)

xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                        model.get_example_inputs())

xm.mark_step()
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
# trace one full training step: forward, backward, optimizer update
output = model(*inputs).sum()
output.backward()
found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
optimizer.step(found_inf=found_inf)
opt_state = []
for name, p in model.named_parameters():
    if p.grad is not None:
        opt_state.append(p)
        opt_state.append(p.grad)
    else:
        print(name, "no grad")
# the loss plus the updated params and their grads become the graph roots;
# hash the graph and warm up the compilation cache
output_list = [output] + opt_state
xla_graph_hash = torch_xla._XLAC._get_graph_hash(output_list)
torch_xla._XLAC._xla_warm_up_cache(output_list, [])
(
    graph_input_tensor_ids,
    graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(output_list)
xla_args_tensor_ids = set(
    tree_map_only(torch.Tensor,
                    lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
                    inputs))
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                        graph_input_tensor_ids,
                                        graph_input_xla_values,
                                        xla_args_tensor_ids)
for i in range(3):
    graph_input = graph_input_matcher(inputs)
    res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input)
    # res[0] (the loss) stays the same every iteration, i.e. the parameters
    # are never updated
    print(res[0])

I think the code below is logically the same as the code above, but the loss changes for the code below while it does not for the code above:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree

class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)

xla_dev = xm.xla_device()
model = M().to(device=xla_dev)
optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                        model.get_example_inputs())

xm.mark_step()
args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
# plain lazy-tensor execution: re-trace and mark_step every iteration;
# here the loss changes as expected
for i in range(3):
    output = model(*inputs).sum()
    output.backward()
    found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
    optimizer.step(found_inf=found_inf)
    optimizer.zero_grad()
    xm.mark_step()
    print("debug ", output)

mars1248 commented 1 month ago

@JackCaoG @dewitt @sprt @ezyang Hello, I have located the root cause of the problem in the test above: only the placeholder is assigned, but the parameters are not, so although new_param is computed, the parameters never change. https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L816-L817 But I don't know how to fix this. Could you give me some ideas?

JackCaoG commented 1 month ago

I will try to take a look tomorrow.

mars1248 commented 1 month ago

@lsy323 @JackCaoG
Thank you for your attention. I have located the problem: it is caused by the weights not being updated. I also found that there is some handling of this situation, https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L588-L589, but the handling seems a bit complicated to me, so I wrote what may be simpler logic. I use the _get_tensors_xla_data_node_hash function to get the backend ptr of every param that needs to be updated, modify the _get_tensors_xla_device_data_node function to determine which inputs exist based on the backend ptr, and finally modify _run_cached_graph to assign the results of the computation to the corresponding parameters. If you think my idea is feasible, I can push a pull request to torch_xla. The main thing is that I have had a lot of problems using fx to trace graphs, as mentioned when this issue was first raised; that is why I built my own solution. Below is a demo test:


import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn
from torch.utils._pytree import tree_map_only
from torch_xla.core.dynamo_bridge import GraphInputMatcher
from torch_xla.amp import syncfree

class M(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(5, 3)

  def forward(self, x):
    return self.linear(x)

  def get_example_inputs(self):
    return (torch.rand(10, 5),)
model = M()

demo_input = model.get_example_inputs()
xla_dev = xm.xla_device()
model = model.to(device=xla_dev)

plist = []

optimizer = syncfree.AdamW(model.parameters(), lr=0.01)
for group in optimizer.param_groups:
    for p in group['params']:
        if not p.requires_grad:
            continue
        state = optimizer.state[p]
        if not state:
            state['step'] = torch.tensor(0.0).to(xm.xla_device())
            state['exp_avg'] = torch.zeros_like(p).to(xm.xla_device())
            state['exp_avg_sq'] = torch.zeros_like(p).to(xm.xla_device())
            state['max_exp_avg_sq'] = torch.empty(
                                      0, dtype=torch.float, device=xm.xla_device())
        plist.append(p)
        #plist.append(state['step'])
        plist.append(state['exp_avg'])
        plist.append(state['exp_avg_sq'])
        #plist.append(state['max_exp_avg_sq'])

# collect a backend handle/hash for every tensor that must be written back
# after the cached graph runs (see the proposal above)
ans = torch_xla._XLAC._get_tensors_xla_data_node_hash(plist)

inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
                        demo_input)

args_tensor_ids = [
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
]
tensor_id_to_arg_idx = {
    tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}
output = model(*inputs).sum()
output.backward()
found_inf = torch.isnan(output).to(torch.float32).to(xla_dev)
optimizer.step(found_inf=found_inf)
optimizer.zero_grad(set_to_none=False)
opt_state = []
for name, p in model.named_parameters():
    if p.grad is not None:
        opt_state.append(p)
        #opt_state.append(optimizer.state[p]["step"])
        opt_state.append(optimizer.state[p]["exp_avg"])
        opt_state.append(optimizer.state[p]["exp_avg_sq"])
        # opt_state.append(optimizer.state[p]["max_exp_avg_sq"])
        # opt_state.append(p.grad)
    else:
        print(name, "no grad")
output_list = [output] + opt_state

xla_graph_hash = torch_xla._XLAC._get_graph_hash(output_list)
torch_xla._XLAC._xla_warm_up_cache(output_list, [])
# modified _get_tensors_xla_device_data_node: also takes the backend hashes
# so the graph inputs belonging to plist can be identified
(
    graph_input_tensor_ids,
    graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(output_list, ans)
# the trailing len(plist) entries are the tensors the cached graph must
# assign its results back to
index = len(graph_input_tensor_ids) - len(plist)
output_list_ans = graph_input_tensor_ids[index:]
graph_input_tensor_ids = graph_input_tensor_ids[:index]
graph_input_xla_values = graph_input_xla_values[:index]
xla_args_tensor_ids = set(
    tree_map_only(torch.Tensor,
                    lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
                    inputs))
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                        graph_input_tensor_ids,
                                        graph_input_xla_values,
                                        xla_args_tensor_ids)

for i in range(3):
    graph_input = graph_input_matcher(inputs)
    # modified _run_cached_graph: assigns the computed values back to the
    # tensors identified by output_list_ans, so the parameters actually update
    res = torch_xla._XLAC._run_cached_graph(xla_graph_hash, graph_input, output_list_ans)
    print(res)

JackCaoG commented 1 month ago

Haha, we started with some simpler code like yours but needed to keep adding logic to handle edge cases and more features. If your refactored version works great for your use case, I think keeping that is a good idea.