wonjoolee95 opened 1 week ago
A couple of things you need to check in the `opt_mode` function. It used to be that only the data was passed, but dynamo expects the backend to remember the parameters of the model. We did a mapping in https://github.com/pytorch/xla/blob/7ec577847a1f29f91e88e05bcab9e9f2c6f5cbad/torch_xla/core/dynamo_bridge.py#L495. If parameters are always passed, then this is easier: we can just hash the parameter shapes and use that as the key to find the corresponding hashes for the compiled program. I think you can start playing with it by turning on dynamic mode (it is disabled for XLA somewhere upstream, I think) and seeing what kind of fx graphs and inputs we get.
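The shape-hashing idea above could be sketched roughly as follows. This is a minimal, hypothetical helper (`shape_cache_key` is not part of `dynamo_bridge`; the key scheme is an assumption), just to illustrate using input shapes/dtypes as a lookup key for compiled programs:

```python
def shape_cache_key(args):
    """Build a hashable cache key from the shapes and dtypes of the inputs.

    Non-tensor args (e.g. plain ints that dynamo passes for dynamic dims)
    are included by value, so different dynamic sizes map to different keys.
    """
    parts = []
    for arg in args:
        if hasattr(arg, "shape"):  # duck-typed: anything tensor-like
            parts.append((tuple(arg.shape), str(arg.dtype)))
        else:
            parts.append(("scalar", arg))
    return tuple(parts)  # hashable, usable directly as a dict key
```

The returned tuple could then index a dict mapping shape signatures to the hashes of already-compiled programs, avoiding recompilation when a shape combination recurs.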
Thanks for the comment, Jack. Digging around for dynamic mode for XLA, it seems we disabled it a while back -- https://github.com/pytorch/xla/pull/5285. Turning it back to True, I start to see the shape of the tensor in `xla_args`:
```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def fn(a: torch.Tensor, b: torch.Tensor):
    c = a + b
    return c

compiled_fn = torch.compile(fn, backend='openxla')

for x, y in [[1, 2], [3, 4], [5, 6], [1, 2]]:
    a = torch.randn(x, y, device=device)
    b = torch.randn(y, device=device)
    ret = compiled_fn(a, b)
    xm.mark_step()
```
```
##### output
# iteration=1, shape=(1,2)
[WONJOO] xla_args=(tensor([[ 1.8011, -0.6563]], device='xla:0'), tensor([-1.1532, -1.2312], device='xla:0'))
[Testing] ret=tensor([[ 0.6479, -1.8875]], device='xla:0')
# iteration=2, shape=(3,4)
[WONJOO] xla_args=(3, 4, tensor([[ 0.4419,  0.4539,  0.3953,  0.0427],
        [-0.3349, -0.4918, -0.2272, -0.6963],
        [ 0.1970,  0.2022,  1.1043, -0.6868]], device='xla:0'), tensor([0.6107, 0.2966, 0.1040, 0.5142], device='xla:0'))
```
```
Traceback (most recent call last):
  File "/home/wonjoo/debug/dynamo_dynamic_shape.py", line 19, in <module>
    ret = compiled_fn(a, b)
  File "/home/wonjoo/pytorch/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjoo/debug/dynamo_dynamic_shape.py", line 5, in fn
    def fn(a: torch.Tensor, b: torch.Tensor):
  File "/home/wonjoo/pytorch/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/wonjoo/pytorch/torch/_functorch/aot_autograd.py", line 986, in forward
    return compiled_fn(full_args)
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 222, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 456, in wrapper
    return compiled_fn(runtime_args)
  File "/home/wonjoo/pytorch/torch/_functorch/_aot_autograd/utils.py", line 94, in g
    return f(*args)
  File "/home/wonjoo/pytorch/torch/_dynamo/backends/torchxla.py", line 36, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/home/wonjoo/pytorch/xla/torch_xla/core/dynamo_bridge.py", line 635, in extract_compiled_graph
    if xla_arg.device.type != 'xla':
AttributeError: 'int' object has no attribute 'device'
I0000 00:00:1720047025.375931 1331731 cpu_client.cc:481] TfrtCpuClient destroyed
```
I assume the first two inputs to the `xla_args` represent the shape (3, 4), so we can ignore the error for now. A couple of questions come to my mind:

1. Why isn't the shape (1, 2) included in the first iteration?
2. (3, 4) represents the shape for the first tensor. Where is the shape (4) that represents the shape for the second tensor?

Let me play around with some more examples.
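For reference, the `AttributeError` above comes from dynamo prepending the dynamic dimension sizes as plain Python ints in front of the tensor arguments, while `extract_compiled_graph` assumes every arg has a `.device`. A minimal, hypothetical guard (not the actual `dynamo_bridge` code) could separate the two before the device check:

```python
def split_dynamic_args(xla_args):
    """Separate dynamic dim sizes (plain ints) from tensor-like arguments.

    With dynamic shapes on, dynamo may pass args like (3, 4, tensor_a, tensor_b);
    the `.device` check should only be applied to the tensor arguments.
    """
    dim_sizes = [a for a in xla_args if isinstance(a, int)]
    tensors = [a for a in xla_args if not isinstance(a, int)]
    return dim_sizes, tensors
```

The `dim_sizes` list could then feed into a shape-based cache key, while `tensors` go through the existing device validation unchanged.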
🚀 Feature
As described in https://github.com/pytorch/pytorch/issues/125958, we are integrating with vLLM on TPUs. We see that in the warm-up phase of vLLM, it needs to pre-compile ~30 different input shape combinations. PyTorch/XLA does not support dynamic shapes today, so torch.compile will keep recompiling the model code, which slows down development speed (waiting for 10 minutes before warm-up is finished). PyTorch/XLA already caches the XLA compilation, but torch.compile itself is pretty expensive.
This feature request proposes to achieve an effect similar to dynamic shapes via persistent caching and real-time lookup of the compiled program.
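A persistent, shape-keyed cache of the kind pitched above might be sketched as follows. This is purely illustrative: `PersistentShapeCache`, its file layout, and the key scheme are assumptions, not an existing PyTorch/XLA API, and a real implementation would store compiled-program hashes rather than arbitrary pickled objects:

```python
import os
import pickle


class PersistentShapeCache:
    """Maps input-shape signatures to previously compiled artifacts on disk."""

    def __init__(self, path):
        self.path = path
        if os.path.exists(path):
            with open(path, "rb") as f:
                self.entries = pickle.load(f)
        else:
            self.entries = {}

    def _key(self, shapes):
        # Key on the tuple of input shapes, e.g. ((3, 4), (4,))
        return tuple(tuple(s) for s in shapes)

    def lookup(self, shapes):
        """Return the cached artifact for these shapes, or None on a miss."""
        return self.entries.get(self._key(shapes))

    def store(self, shapes, compiled_artifact):
        """Record an artifact for these shapes and persist the cache to disk."""
        self.entries[self._key(shapes)] = compiled_artifact
        with open(self.path, "wb") as f:
            pickle.dump(self.entries, f)
```

On a warm-up run the cache would be populated once per shape combination; subsequent runs would hit `lookup` and skip both torch.compile and XLA compilation for shapes seen before.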
Details
To do this, at a high level, we need to do the following:
Open questions
cc @JackCaoG @WoosukKwon