pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Dynamo persistent cache real-time look-up #7614

Open wonjoolee95 opened 1 week ago

wonjoolee95 commented 1 week ago

🚀 Feature

As described in https://github.com/pytorch/pytorch/issues/125958, we are integrating with vLLM on TPUs. In the warm-up phase of vLLM, it needs to pre-compile ~30 different input-shape combinations. PyTorch/XLA does not support dynamic shapes today, so torch.compile keeps recompiling the model code, which slows down development (waiting ~10 minutes before warm-up finishes). PyTorch/XLA already caches the XLA compilation, but torch.compile itself is pretty expensive.

This feature request proposes achieving a similar effect to dynamic shapes through persistent caching and real-time look-up of the compiled program.
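To make the idea concrete, here is a minimal sketch of the pitched flow, assuming nothing about the eventual PyTorch/XLA API (PersistentShapeCache and its methods are hypothetical names): persist a mapping from input shapes to the compiled-program hash, and look it up at call time so a shape seen in a previous run never pays the torch.compile cost again.

import os
import pickle


class PersistentShapeCache:
    """Hypothetical persistent table: input shapes -> compiled-program hash."""

    def __init__(self, path="/tmp/xla_shape_cache.pkl"):
        self.path = path
        self.table = {}
        if os.path.exists(path):
            with open(path, "rb") as f:
                self.table = pickle.load(f)

    @staticmethod
    def key(args):
        # Only shapes and dtypes matter for the key, not the tensor values.
        return tuple((tuple(a.shape), str(a.dtype)) for a in args)

    def get(self, args):
        # Real-time look-up: returns the stored hash, or None on a miss.
        return self.table.get(self.key(args))

    def put(self, args, compiled_program_hash):
        self.table[self.key(args)] = compiled_program_hash
        with open(self.path, "wb") as f:
            pickle.dump(self.table, f)

With a table like this, the vLLM warm-up would pay the full torch.compile cost only once per shape across runs; later runs would resolve each of the ~30 shape combinations to an already-compiled program.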

Details

To do this, at a high level, we need to do the following:

Open questions

cc @JackCaoG @WoosukKwon

JackCaoG commented 1 week ago

A couple of things you need to check:

  1. Whether parameters are always passed as inputs to the opt_mode function. It used to be that only the data was passed, but Dynamo expects the backend to remember the parameters of the model; we did a mapping in https://github.com/pytorch/xla/blob/7ec577847a1f29f91e88e05bcab9e9f2c6f5cbad/torch_xla/core/dynamo_bridge.py#L495. If parameters are always passed, then this is easier: we can just hash the parameter shapes and use that as the key to find the corresponding hashes for the compiled program (see the sketch after this list).
  2. Whether the hash for the compiled binary is the only thing we need to switch between different shapes. Do we need a copy of https://github.com/pytorch/xla/blob/7ec577847a1f29f91e88e05bcab9e9f2c6f5cbad/torch_xla/core/dynamo_bridge.py#L433-L435 for every compiled binary? If so, how do we obtain it?
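A rough sketch of what points 1 and 2 could look like together, assuming parameters do always arrive as inputs; shape_key, CacheEntry, and per_binary_state are illustrative names, not existing dynamo_bridge identifiers.

from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


def shape_key(args) -> Tuple:
    # Point 1: if parameters always arrive as inputs, hashing just the
    # shapes/dtypes of args is enough to identify the compiled program.
    return tuple((tuple(a.shape), str(a.dtype)) for a in args)


@dataclass
class CacheEntry:
    # Hash used to fetch the already-compiled XLA binary.
    graph_hash: Any
    # Point 2: any per-binary bookkeeping the bridge keeps today
    # (e.g. the graph-input metadata around dynamo_bridge.py#L433-L435)
    # would have to be stored per entry as well.
    per_binary_state: Dict[str, Any] = field(default_factory=dict)


_cache: Dict[Tuple, CacheEntry] = {}


def lookup(args):
    # Returns the entry for these shapes, or None if we still need to compile.
    return _cache.get(shape_key(args))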

I think you can start playing with it by turning on dynamic mode (it is disabled for XLA somewhere in upstream, I think) and seeing what kind of FX graphs and inputs we get.

wonjoolee95 commented 6 days ago

Thanks for the comment, Jack. Digging around for dynamic mode for XLA, it seems like we disabled it a while back -- https://github.com/pytorch/xla/pull/5285. Turning it back to True, I start to see the shapes of the tensors in xla_args:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def fn(a: torch.Tensor, b: torch.Tensor):
    c = a + b
    return c

compiled_fn = torch.compile(fn, backend='openxla')

# Same compiled function called with several different input shapes.
for x, y in [[1, 2], [3, 4], [5, 6], [1, 2]]:
    a = torch.randn(x, y, device=device)
    b = torch.randn(y, device=device)
    ret = compiled_fn(a, b)
    xm.mark_step()

##### output
# iteration=1, shape=(1,2)
[WONJOO] xla_args=(tensor([[ 1.8011, -0.6563]], device='xla:0'), tensor([-1.1532, -1.2312], device='xla:0'))
[Testing] ret=tensor([[ 0.6479, -1.8875]], device='xla:0')
# iteration=2, shape=(3,4)
[WONJOO] xla_args=(3, 4, tensor([[ 0.4419,  0.4539,  0.3953,  0.0427],
        [-0.3349, -0.4918, -0.2272, -0.6963],
        [ 0.1970,  0.2022,  1.1043, -0.6868]], device='xla:0'), tensor([0.6107, 0.2966, 0.1040, 0.5142], device='xla:0'))
Traceback (most recent call last):
  File "/home/wonjoo/debug/dynamo_dynamic_shape.py", line 19, in <module>
    ret = compiled_fn(a, b)
  File "/home/wonjoo/pytorch/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjoo/debug/dynamo_dynamic_shape.py", line 5, in fn
    def fn(a: torch.Tensor, b: torch.Tensor):
  File "/home/wonjoo/pytorch/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjoo/pytorch/torch/_functorch/aot_autograd.py", line 986, in forward
    return compiled_fn(full_args)
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 222, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 456, in wrapper
    return compiled_fn(runtime_args)
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/utils.py", line 94, in g
    return f(*args)
  File "/home/wonjoo/pytorch/torch/_dynamo/backends/torchxla.py", line 36, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/home/wonjoo/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 635, in extract_compiled_graph
    if xla_arg.device.type != 'xla':
AttributeError: 'int' object has no attribute 'device'
I0000 00:00:1720047025.375931 1331731 cpu_client.cc:481] TfrtCpuClient destroyed

I assume the first two entries in xla_args represent the shape (3, 4), so we can ignore the error for now. A couple of questions come to mind:

Let me play around with some more examples..
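As a sanity check on the failure above, here is a small sketch (not the actual dynamo_bridge fix; split_xla_args is a hypothetical helper) of how the device check could tolerate the int arguments, assuming those ints really are the dynamic dimension sizes (3, 4) that dynamo prepends:

import torch


def split_xla_args(xla_args):
    # With dynamic mode on, plain Python ints (the dynamic dimension sizes)
    # show up alongside the tensors, so only tensors should be device-checked.
    dim_args = [a for a in xla_args if isinstance(a, int)]
    tensor_args = [a for a in xla_args if isinstance(a, torch.Tensor)]
    for t in tensor_args:
        assert t.device.type == 'xla', f"expected an XLA tensor, got {t.device}"
    return dim_args, tensor_args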