lucidrains / ring-attention-pytorch

Implementation of 💍 Ring Attention, from Liu et al. at Berkeley AI, in Pytorch
MIT License

ValueError: Invalid expression '[ True]', must be integers #3

Closed kyegomez closed 8 months ago

kyegomez commented 8 months ago

```
pytree_node instead.
  _torch_pytree._register_pytree_node(
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Traceback (most recent call last):
  File "/Users/defalt/Desktop/Athena/research/Gemini/gemini_block.py", line 18, in <module>
    out = model(x)  # Apply the model to the input tensor
          ^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/defalt/Desktop/Athena/research/Gemini/gemini_torch/model.py", line 101, in forward
    x = self.attn(x)
        ^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/ring_attention_pytorch/ring_attention.py", line 228, in forward
    q, k, v = rearrange('b n (qkv h d) -> qkv b h n d', qkv, qkv = 3, h = self.heads)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 70, in inner
    graph = construct_graph(*args, backend=backend, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 45, in construct_graph
    output_tracers = func(*args, **kwargs, backend=einx.backend.tracer)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 118, in rearrange
    exprs_in, exprs_out = parse(description, *[einx.param.get_shape(tensor) for tensor in tensors], cse=cse, **parameters)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/lru_cache.py", line 20, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 59, in parse
    + [einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None) for k, v in parameters.items()],
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/op/rearrange.py", line 59, in <listcomp>
    + [einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None) for k, v in parameters.items()],
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/expr/util.py", line 36, in __init__
    self.expr2 = _input_expr(expr2)
                 ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/einx/expr/util.py", line 29, in _input_expr
    raise ValueError(f"Invalid expression '{expr}', must be integers")
ValueError: Invalid expression '[ True]', must be integers
```

lucidrains commented 8 months ago

@kyegomez ah yea, maybe einx is still a bit new

reverted back to einops for now

fferflo commented 8 months ago

I might be overlooking something, but it's possible the error stems from passing heads=True rather than an int to RingAttention in https://github.com/kyegomez/Gemini/blob/main/gemini_torch/long_gemini.py#L65, in which case it might appear with einops as well.
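
For illustration, here is a minimal sketch of a guard that would catch this kind of mistake before it reaches the rearrange call. The traceback shows einx wrapping each keyword value with `np.asarray(v)[..., np.newaxis]`, which for `True` gives a one-element boolean array, consistent with the `'[ True]'` in the error message. The helper name `check_heads` is hypothetical and not part of ring-attention-pytorch or einx; it only assumes that `heads` is meant to be a positive integer.

```python
import operator


def check_heads(heads):
    """Hypothetical guard: accept only real integers for `heads`, never bools."""
    # bool is a subclass of int in Python, so isinstance(True, int) is True.
    # It has to be rejected explicitly or heads=True slips through as 1.
    if isinstance(heads, bool):
        raise TypeError(f"heads must be an int, got bool ({heads!r})")
    # operator.index accepts int-like objects and rejects floats and strings.
    heads = operator.index(heads)
    if heads <= 0:
        raise ValueError(f"heads must be positive, got {heads}")
    return heads


print(check_heads(8))

try:
    check_heads(True)
except TypeError as err:
    print(err)
```

A check like this at the call site (or in the module's constructor) would presumably turn the long traceback above into a one-line error pointing at the argument, whichever of einx or einops does the rearranging.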

lucidrains commented 8 months ago

@fferflo oh hey Florian! yeah this is my bad, should have checked instead of assuming

let me revert back to einx, been using it in a couple projects now and so far so good!

fferflo commented 8 months ago

Hey Phil, you're totally good, I was just wondering if I had missed a bug somewhere and thought I'd share my findings here.

I love that you're finding it helpful!

lucidrains commented 8 months ago

@fferflo dinner's on me if you ever come to SF! 🤣 i can invite Alex too (he's in south san fran)

fferflo commented 8 months ago

That would be really awesome, although it's quite the trip from Germany so I'm not sure when I'll be able to cash in on the dinner 😄

kyegomez commented 8 months ago

@lucidrains @fferflo no invite 😔