bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Upstream changes make the demo not work #536

Open hrQAQ opened 8 months ago

hrQAQ commented 8 months ago

I'm following the tutorial Run-Petals-server-on-Windows to start up a server on my own PC. Upon running python -m petals.cli.run_server petals-team/StableBeluga2, I encountered the following error:

(base) horik@asus:~$ python -m petals.cli.run_server petals-team/StableBeluga2
Nov 07 16:05:49.690 [INFO] Running Petals 2.3.0.dev0
Nov 07 16:05:52.285 [INFO] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1
Nov 07 16:05:52.285 [INFO] Using DHT prefix: StableBeluga2-hf
Nov 07 16:06:09.845 [INFO] This server is accessible via relays
Nov 07 16:06:15.377 [INFO] Connecting to the public swarm
Nov 07 16:06:15.378 [INFO] Running a server on ['/ip4/192.168.185.162/tcp/40783/p2p/12D3KooWAoVXAq9YkSmYVASCmGwLeRhZeKWJqUjFF5CURnREvqU1', '/ip4/127.0.0.1/tcp/40783/p2p/12D3KooWAoVXAq9YkSmYVASCmGwLeRhZeKWJqUjFF5CURnREvqU1', '/ip6/::1/tcp/46511/p2p/12D3KooWAoVXAq9YkSmYVASCmGwLeRhZeKWJqUjFF5CURnREvqU1']
Nov 07 16:06:15.612 [INFO] Model weights are loaded in bfloat16, quantized to nf4 format
Nov 07 16:06:15.619 [INFO] Server will fill your GPU memory with 5 transformer blocks. If you want to leave some free GPU memory, please specify a lesser --num_blocks manually
Nov 07 16:06:15.620 [INFO] Attention cache for all blocks will consume up to 0.31 GiB
Nov 07 16:06:15.620 [INFO] Loading throughput info
Nov 07 16:06:15.620 [INFO] Measuring network and compute throughput. This takes about a minute and will be cached for future runs
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/cli/run_server.py", line 235, in <module>
    main()
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/cli/run_server.py", line 219, in main
    server = Server(
             ^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/server/server.py", line 237, in __init__
    throughput_info = get_server_throughput(
                      ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/server/throughput.py", line 82, in get_server_throughput
    cache[cache_key] = measure_throughput_info(
                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/server/throughput.py", line 122, in measure_throughput_info
    "inference_rps": measure_compute_rps(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/server/throughput.py", line 210, in measure_compute_rps
    _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/tensor_parallel/tensor_parallel.py", line 99, in forward
    return [self.module_shards[0](*args, **kwargs)][self.output_device_index]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/horik/miniconda3/lib/python3.11/site-packages/petals/models/llama/block.py", line 48, in forward
    attention_mask = LlamaModel._prepare_decoder_attention_mask(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: type object 'LlamaModel' has no attribute '_prepare_decoder_attention_mask'

After searching the web, I found that the root cause of this error may be the upstream refactoring of the attention mask logic in transformers; the related commit page is here.

I see two possible solutions to this issue. The first is to pin a previous version of the transformers library when installing dependencies. The second is to adapt to the new attention mask implementation, which requires modifying petals/models/llama/block.py.
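For the second option, here is a minimal sketch of a version-tolerant helper that block.py could call instead of `LlamaModel._prepare_decoder_attention_mask`. It rests on two assumptions. First, transformers 4.35 exposes the replacement as the private helper `_prepare_4d_causal_attention_mask` in `transformers.modeling_attn_mask_utils`, and because that helper is private, its location and signature may change again. Second, the old method can be called with `None` in place of `self`. The function name `prepare_attention_mask` is hypothetical, and this is not Petals' actual fix:

```python
def prepare_attention_mask(attention_mask, batch_size, seq_length, hidden_states, past_key_values_length):
    """Build the 4D additive causal mask (or None) with whichever helper the
    installed transformers provides. Hypothetical helper, sketch only."""
    try:
        # Assumed location of the replacement helper in transformers >= 4.35.
        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
    except ImportError:
        _prepare_4d_causal_attention_mask = None

    if _prepare_4d_causal_attention_mask is not None:
        # Assumed new signature: (attention_mask, input_shape, inputs_embeds,
        # past_key_values_length); inputs_embeds supplies the dtype and device.
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

    # transformers < 4.35: the private method removed by the refactor. It is an
    # instance method, but None is passed for `self` (assumed unused).
    from transformers.models.llama.modeling_llama import LlamaModel

    return LlamaModel._prepare_decoder_attention_mask(
        None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )
```

The import is resolved inside the function to keep the sketch self-contained. Real code would do the lookup once at module level.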

hrQAQ commented 8 months ago

Changelog from the transformers 4.35.0 release:

Attention mask refactor

We refactored the attention mask logic for major models in transformers. For instance, we removed padding_mask argument which was ambiguous for some users.

Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper by @patrickvonplaten in https://github.com/huggingface/transformers/pull/26792

[Attention Mask] Refactor all encoder-decoder attention mask by @patrickvonplaten in https://github.com/huggingface/transformers/pull/27086
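The changelog mentions a "2D->4D Attn Mask Mapper". The following pure-Python sketch illustrates that idea only and is not the code from the PRs above. It expands a [batch, kv_length] padding mask into the additive [batch, 1, query_length, kv_length] causal mask that a decoder block adds to its attention scores. Plain lists stand in for tensors. `MASKED` is a hypothetical stand-in for the large negative value used in practice, such as `torch.finfo(dtype).min`. The function name `expand_2d_to_4d_causal` is also hypothetical:

```python
# Hypothetical stand-in for the large negative value added to masked logits.
MASKED = -1e9


def expand_2d_to_4d_causal(attention_mask_2d, query_length):
    """Expand a [batch, kv_length] padding mask (1 = keep, 0 = pad) into an
    additive [batch, 1, query_length, kv_length] causal mask.

    The last `query_length` key positions are the new tokens; everything
    before them is the cached past (past_length = kv_length - query_length).
    """
    result = []
    for row in attention_mask_2d:
        kv_length = len(row)
        past_length = kv_length - query_length
        per_query = []
        for q in range(query_length):
            # Query q sits at absolute position past_length + q and may attend
            # to every non-padded key at or before that position.
            per_query.append([
                0.0 if (k <= past_length + q and row[k] == 1) else MASKED
                for k in range(kv_length)
            ])
        result.append([per_query])  # singleton head dimension
    return result
```

With `expand_2d_to_4d_causal([[1, 1, 1]], query_length=2)`, the first of the two new tokens sees the cached token and itself, while the second sees all three keys.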

After running the following commands, my Windows server started successfully:

pip uninstall transformers
pip install transformers==4.34.0

So, as a minimal change, I suggest pinning the transformers version in setup.cfg.
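Pinning in setup.cfg stops pip from pulling in 4.35.0 during installation, but someone can still install a newer transformers afterwards and hit the same AttributeError. Below is a minimal sketch of a complementary runtime check that uses only the standard library. The `FIRST_INCOMPATIBLE` bound reflects only what this thread reports: 4.34.0 works, and 4.35.0 contains the refactor. That bound and the helper names are assumptions for illustration, not part of Petals:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumed bound: 4.34.0 is reported to work here, 4.35.0 introduced the refactor.
FIRST_INCOMPATIBLE = (4, 35, 0)


def parse_release(version_string):
    """Turn '4.34.0' or '4.35.0.dev0' into a comparable (major, minor, patch) tuple."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def is_compatible(version_string):
    """True if the given transformers version predates the attention mask refactor."""
    return parse_release(version_string) < FIRST_INCOMPATIBLE


def check_transformers():
    """Fail early with a readable message instead of the AttributeError above."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        return  # not installed; the later import will fail on its own
    if not is_compatible(installed):
        raise RuntimeError(
            f"transformers {installed} is not supported yet; "
            "install transformers==4.34.0 until block.py is updated"
        )
```

Calling `check_transformers()` before the server loads any blocks would surface the version mismatch before the throughput measurement starts.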

justheuristic commented 8 months ago

Thank you for reporting this. I'm pinning the transformers version for now. In the meantime, we're working on fixing the problem upstream.