kyegomez / LongNet

Implementation of plug in and play Attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens"

LongNetTransformer Error #25

Open LiJiahao-Alex opened 2 months ago

LiJiahao-Alex commented 2 months ago

I ran the example program and got the following error.

import torch
from long_net.model import LongNetTransformer

longnet = LongNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
).to("cuda:0")

tokens = torch.randint(0, 20000, (1, 512)).to("cuda:0")
logits = longnet(tokens)
print(logits)

It looks like there's something wrong internally?

2024-07-08 01:43:03.002114: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-08 01:43:03.048251: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-08 01:43:03.679049: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-07-08 01:43:04,742 - numexpr.utils - INFO - Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
2024-07-08 01:43:04,742 - numexpr.utils - INFO - Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-07-08 01:43:04,742 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda
Traceback (most recent call last):
  File "/workspace/DeepVQ/model/LongNetGPT.py", line 20, in <module>
    logits = longnet(tokens)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/long_net/model.py", line 302, in forward
    x = self.transformer(x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/long_net/model.py", line 271, in forward
    x = block(x) + x
RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 1

Process finished with exit code 1


SanderMoon commented 1 month ago

Having the same issue on Python 3.11. I believe it is caused by the skip connections in the Transformer class's forward() function. The first ParallelTransformerBlock has a default dilation rate of 2, so it produces half as many output tokens (256 instead of the 512 a standard transformer would return in this example). You can verify this with the other example that uses DilatedAttention directly. Because of this, the skip connection can no longer be added: the sequence dimension of the dilated output no longer matches that of the original input.

I'm not sure what the original authors' intention was here, though.
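To make the shape mismatch concrete, here is a minimal sketch in plain PyTorch. The slicing below is a hypothetical stand-in for dilation, not the library's actual DilatedAttention code: a block that keeps only every second token halves the sequence dimension, so the residual addition x = block(x) + x in model.py can no longer broadcast.

import torch

# Hypothetical stand-in for a dilated block: keeping every second token
# (dilation rate 2) halves the sequence length from 512 to 256.
x = torch.randn(1, 512, 512)   # (batch, seq_len, dim)
dilated_out = x[:, ::2, :]     # shape (1, 256, 512)

try:
    dilated_out + x            # mirrors `x = block(x) + x` in long_net/model.py
except RuntimeError as e:
    print(e)  # The size of tensor a (256) must match the size of tensor b (512) ...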

Steps to reproduce:

  1. Create a new venv (python3.11 -m venv venv) and activate it (source ./venv/bin/activate)
  2. Install LongNet as described in the README: pip install longnet
  3. Run the code from the README for the LongNetTransformer:
import torch
from long_net.model import LongNetTransformer

longnet = LongNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
)

tokens = torch.randint(0, 20000, (1, 512))
logits = longnet(tokens)
print(logits)

Produces:

Traceback (most recent call last):
  File "/Users/sander.moonemans/Study/Thesis/Longnet/longnet_transformer.py", line 14, in <module>
    logits = longnet(tokens)
             ^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/long_net/model.py", line 302, in forward
    x = self.transformer(x)
        ^^^^^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sander.moonemans/Study/Thesis/Longnet/venv/lib/python3.11/site-packages/long_net/model.py", line 271, in forward
    x = block(x) + x
        ~~~~~~~~~^~~
RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 1