
# torch.compile + ring attention #121386

Closed · yanboliang closed this 1 month ago

yanboliang commented 8 months ago

🐛 Describe the bug

I was trying to enable ring attention with torch.compile; here are the issues I encountered:

* einx rearrange issue

```
from user code:
  File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/ring_attention.py", line 206, in torch_dynamo_resume_in_sharded_batch_to_sharded_seq_at_194
    x = rearrange('(b s) n -> b (s n)', x, s = num_sharded_batches)
```

```
Traceback (most recent call last):
  File "/home/ybliang/local/pytorch/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/assert.py", line 85, in start
    ring_out = ddp_ring_attention_net(seq)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 437, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/parallel/distributed.py", line 1593, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/parallel/distributed.py", line 1411, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/ring_attention.py", line 548, in forward
    (x, mask), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)
  File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/ring_attention.py", line 600, in torch_dynamo_resume_in_forward_at_548
    logits = rearrange('b (i j) d -> b (j i) d', logits, j = self.bucket_size)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/lru_cache.py", line 66, in inner
    backend = einx.backend.get(input_tracer_values)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/lru_cache.py", line 70, in torch_dynamo_resume_in_inner_at_66
    graph = construct_graph(*args, backend=backend, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/lru_cache.py", line 45, in construct_graph
    output_tracers = func(*args, **kwargs, backend=einx.backend.tracer)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/op/rearrange.py", line 118, in rearrange
    exprs_in, exprs_out = parse(description, *[einx.param.get_shape(tensor) for tensor in tensors], cse=cse, **parameters)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/op/rearrange.py", line 56, in parse
    exprs = einx.expr.solve(
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/util.py", line 108, in solve
    exprs1, exprs2 = stage3.solve(exprs1, exprs2)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 333, in solve
    exprs1 = [map(root) if not root is None else None for root in exprs1]
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 333, in <listcomp>
    exprs1 = [map(root) if not root is None else None for root in exprs1]
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 324, in map
    return List([map(child) for child in expr.children])
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 324, in <listcomp>
    return List([map(child) for child in expr.children])
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 330, in map
    return Composition.maybe(map(expr.inner))
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 324, in map
    return List([map(child) for child in expr.children])
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 65, in __init__
    Expression.__init__(self, np.prod([c.value for c in children]).astype(int))
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/einx/expr/stage3.py", line 9, in __init__
    raise TypeError(f"Expected int, got {type(value)}")
TypeError: Expected int, got <class 'numpy.ndarray'>
```

* torch.distributed-related graph break 1:

```
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks] Graph break: from user code at:
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     return fn(*args, **kwargs)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     return forward_call(*args, **kwargs)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/nn/parallel/distributed.py", line 1589, in forward
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/nn/parallel/distributed.py", line 1464, in _pre_forward
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     self.reducer.prepare_for_forward()
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks] Traceback (most recent call last):
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks] ......
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 674, in call_function
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     self.push(fn.call_function(self, args, kwargs))
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/variables/user_defined.py", line 687, in call_function
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     return self.call_method(tx, "__call__", args, kwargs)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/variables/user_defined.py", line 579, in call_method
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     return super().call_method(tx, name, args, kwargs)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/variables/base.py", line 371, in call_method
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks]     raise Unsupported(msg)
[rank3]:V0306 22:17:59.702000 139967781041280 torch/_dynamo/symbolic_convert.py:516] [0/0] [__graph_breaks] torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(instancemethod) __call__ [] {}
```
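This break surfaces in DDP's pre-forward hook (`self.reducer.prepare_for_forward()`), which Dynamo cannot trace. A minimal sketch of the two usual ways to combine torch.compile with DDP; `net` is a hypothetical stand-in for the ring-attention module, and a process group is assumed to be initialized (e.g. via torchrun):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

net = nn.Linear(1024, 1024).cuda()  # hypothetical stand-in module

# Option A: compile the DDP wrapper. Dynamo recognizes DDP and applies
# DDPOptimizer, splitting the graph at gradient-bucket boundaries so the
# reducer's hooks run in eager Python between compiled subgraphs.
model_a = torch.compile(DDP(net))

# Option B: compile only the inner module, keeping all of DDP's
# bookkeeping (including reducer.prepare_for_forward) out of the trace.
model_b = DDP(torch.compile(net))
```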

Graph break 2:

```
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks] Graph break: from user code at:
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/ring_attention.py", line 228, in sharded_seq_to_sharded_batch
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     logits, = all_gather(logits)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     return forward_call(*args, **kwargs)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/distributed.py", line 85, in forward
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     return AllGatherFunction.apply(x, self.dim, sizes)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/distributed.py", line 68, in forward
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/distributed.py", line 42, in all_gather_variable_dim
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     sizes = gather_sizes(t, dim = dim)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/distributed.py", line 32, in gather_sizes
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     sizes = all_gather_same_dim(size)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/data/users/ybliang/debug/empathy/ring-attention-pytorch/ring_attention_pytorch/distributed.py", line 27, in all_gather_same_dim
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     dist.all_gather(gathered_tensors, t)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/distributed/_functional_collectives.py", line 1036, in all_gather_inplace
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     output = all_gather_tensor(tensor, 0, group, tag)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/distributed/_functional_collectives.py", line 227, in all_gather_tensor
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     group_name = _resolve_group_name(group, tag)
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/distributed/_functional_collectives.py", line 758, in _resolve_group_name
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks]     raise ValueError(f"Unsupported group type: {type(group)}, {group}")
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks] Traceback (most recent call last):
[rank4]:V0306 22:18:13.253000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [32/0] [__graph_breaks] ......
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]     self.call_function(fn, args, {})
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 674, in call_function
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]     self.push(fn.call_function(self, args, kwargs))
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/variables/builtin.py", line 719, in call_function
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]     return super().call_function(tx, args, kwargs)
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/variables/base.py", line 352, in call_function
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]     unimplemented(f"call_function {self} {args} {kwargs}")
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]   File "/home/ybliang/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks]     raise Unsupported(msg)
[rank4]:V0306 22:17:59.812000 139785056863360 torch/_dynamo/symbolic_convert.py:516] [1/0] [__graph_breaks] torch._dynamo.exc.Unsupported: call_function BuiltinVariable(ValueError) [ConstantVariable(str: "Unsupported group type: <class 'NoneType'>, None")] {}
```
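Here the in-place `dist.all_gather` is rerouted under compile to the functional collective, whose `_resolve_group_name` rejects the implicit `group=None`. A minimal sketch, assuming an already-initialized default process group, of calling the traceable functional collective with an explicit group:

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Assumes dist.init_process_group(...) has already run (e.g. under torchrun).
t = torch.ones(4, device="cuda")

# Passing the default world group explicitly gives _resolve_group_name a
# concrete group to resolve, instead of the unsupported None seen above.
out = funcol.all_gather_tensor(t, gather_dim=0, group=dist.group.WORLD)
```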



### Versions

N/A

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng
kabachuha commented 7 months ago

@yanboliang Please open this issue at https://github.com/lucidrains/ring-attention-pytorch as well; @lucidrains has experience with the code.

lucidrains commented 7 months ago

@kabachuha @yanboliang hey Yanbo, thanks for your interest in ring attention.

Could you try 0.3.4? I reverted back to using einops for now.

edit: I think you can close this issue regardless, as it has nothing to do with pytorch.

anijain2305 commented 1 month ago

@yanboliang any update on this issue?

yanboliang commented 1 month ago

Ring attention works well with torch.compile now.