bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Cannot use direct server-to-server communication #550

Closed · miaoqijun closed this 3 months ago

miaoqijun commented 6 months ago

Hello, I am learning the implementation of the inference part and trying to run it on my private swarm. I found that server-to-server communication appears to be unavailable (even with the option `use_server_to_server` kept at its default value of True in the client's config).

**How I found the problem:** I added some `print()` calls in `petals/server/handler.py/_iterate_inference_steps():295`, like this:

if anext_task in done:
    request = await anext_task
    anext_task = None
    print(f'anext_task done first')
elif get_push_task in done:
    request = await get_push_task
    get_push_task = None
    print(f'get_push_task done first')

to determine whether the client or the previous server provides the inputs to the current server. The result is always `anext_task done first`, even if I add a delay on the client in the hope that inputs from the previous server can arrive earlier.
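
For illustration, the handler presumably obtains `done` from something like `asyncio.wait(..., return_when=FIRST_COMPLETED)`; the minimal, self-contained sketch below (not petals code; function names and delays are made up) shows how racing a "client input" task against a "pushed input" task reveals which source delivered the step first, which is exactly what the `print()` calls above report:

```python
import asyncio

async def inputs_from_client(delay: float) -> str:
    await asyncio.sleep(delay)
    return "inputs received from the client"

async def inputs_from_push(delay: float) -> str:
    await asyncio.sleep(delay)
    return "inputs pushed by the previous server"

async def main() -> None:
    # Two competing sources for the next inference step, analogous to
    # anext_task (client) and get_push_task (previous server) in handler.py.
    anext_task = asyncio.create_task(inputs_from_client(0.2))
    get_push_task = asyncio.create_task(inputs_from_push(0.1))

    done, pending = await asyncio.wait(
        {anext_task, get_push_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if anext_task in done:
        print("anext_task done first:", await anext_task)
    elif get_push_task in done:
        print("get_push_task done first:", await get_push_task)

    for task in pending:  # clean up whichever source lost the race
        task.cancel()

asyncio.run(main())
```

When pushes work, the push should win this race; in my swarm it never did.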

**Possible cause:** I then read the relevant code and found the reason: the client only passes `next_servers` in the metadata after the first step. However, the servers always use the metadata from the first step to determine which server to push their outputs to (in `petals/server/handler.py/rpc_inference():188`):

if can_push:    # the metadata here is from the first step and is never updated during the session
    task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))

So the servers always end up with an empty `next_servers` list in `_push_outputs()`.

I then found that the metadata for the following steps is extracted in `petals/server/block_functions.py/iterate_rpc_inference():162` but never used:

async for request, step_metadata in input_iterator:    # step_metadata is unused in the function

So I made the following local modifications (a simplified sketch of the resulting flow follows these snippets):

petals/server/block_functions.py/iterate_rpc_inference():227

yield output_tensors, can_push, step_metadata

petals/server/handler.py/rpc_inference():174

async for output_tensors, can_push, step_metadata in iterate_rpc_inference(

petals/server/handler.py/rpc_inference():189

task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
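
For illustration, here is a minimal, self-contained sketch of the flow these three one-line changes produce (simplified and not the actual petals code; `push_outputs`, `request_stream`, and the metadata layout are stand-ins): the inner generator yields the current step's metadata alongside the outputs, and the outer loop pushes using that per-step metadata instead of reusing the first-step metadata.

```python
import asyncio
from typing import AsyncIterator, Dict, List, Tuple

StepMetadata = Dict[str, List[str]]

async def iterate_rpc_inference(
    requests: AsyncIterator[Tuple[str, StepMetadata]]
) -> AsyncIterator[Tuple[str, StepMetadata]]:
    """Stand-in for the inner generator: yield outputs together with *this step's* metadata."""
    async for request, step_metadata in requests:
        output = f"hidden states for {request}"  # placeholder for the real block forward pass
        yield output, step_metadata              # the change: per-step metadata is surfaced

async def push_outputs(output: str, metadata: StepMetadata) -> None:
    """Stand-in for _push_outputs: route based on the metadata it was actually given."""
    if metadata.get("next_servers"):
        print(f"pushing {output!r} to {metadata['next_servers']}")
    else:
        print(f"no next_servers known, returning {output!r} to the client instead")

async def request_stream() -> AsyncIterator[Tuple[str, StepMetadata]]:
    yield "step 0", {"next_servers": []}          # first step: client has not sent next_servers yet
    yield "step 1", {"next_servers": ["peer-B"]}  # later steps: next_servers is known

async def main() -> None:
    async for output, step_metadata in iterate_rpc_inference(request_stream()):
        # Before the fix, this call always received the (empty) first-step metadata,
        # so pushes never happened; with the fix it sees the current step's metadata.
        await push_outputs(output, step_metadata)

asyncio.run(main())
```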

After that, I ran it on my private swarm and found that server-to-server communication works.

**Question:** I have only read part of the code and do not have a full understanding of the whole system, so I would like to know whether my modifications are correct or whether you implemented it this way on purpose. Thank you very much for your review and reply.

cybershrapnel commented 3 months ago

i think they abandoned open source for the paid model lol

mryab commented 3 months ago

@cybershrapnel, I'm not sure where you got this impression, but Petals has no paid model and we have no intention of developing one :) It's true that our current team of maintainers is very short on people, but we are trying to develop additional features for the library, and we are happy to support contributions (e.g., PRs) from the community.

mryab commented 3 months ago

@miaoqijun Thanks for the observation! Tagging @justheuristic just to be sure, but it looks like a very good catch: it might be that we overlooked how the metadata is built across multiple inference steps. We'll try to look into it.

justheuristic commented 3 months ago

Hi @mryab, @miaoqijun,

lemme look into this, will write back soon

justheuristic commented 3 months ago

Okay, so the problem OP described is certainly still there. It's a shame that it took us so long to get to it :sweat_smile:

@miaoqijun, thank you a lot for the work you put into writing this issue.

For reproducibility, here's how I tested it:

  1. After this line: https://github.com/bigscience-workshop/petals/blob/c08d09c/src/petals/server/handler.py#L324, I added `print("checking if should push to next_servers:", next_servers)`

Note: the commit ID points to the main branch as of now.

  2. Test setup:
    
    # terminal 1 - initial peer
    python -m petals.cli.run_dht --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337
    # terminals 2 and 3 - repeat the same script
    python -m petals.cli.run_server $MODEL_NAME --num_blocks 4 --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS

  3. Terminal 4 - run the inference test:

pytest test_full_model.py::test_full_model_exact_match -s
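
Here `$MODEL_NAME` should be set to the model being tested and `$INITIAL_PEERS` to the full multiaddr of the DHT peer started in terminal 1.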



Outputs from the first server match what @miaoqijun reported earlier:
![image](https://github.com/bigscience-workshop/petals/assets/3491902/f5ad3d14-a790-43a2-bf24-d79b2892989d)

Note that the client *knows* the next servers during the first request, but it withholds them because processing the prefix with pushes can be invalid in some cases (e.g., if the client wishes to modify intermediate activations via prefix tuning).

Of the two alternative solutions (either send `next_servers` in the first request, or what @miaoqijun proposed), the latter is more general because it also covers cases where `next_servers` changes during inference (e.g., the next server experiences a hardware failure).

I will now reopen this as a PR and find a way to properly credit @miaoqijun in that pull request.

miaoqijun commented 3 months ago

Thx for fixing this, and I'm happy with my contribution, even if it's a bit late :)