exo-explore / exo

Run your own AI cluster at home with everyday devices 📱💻 🖥️⌚

Docs: Linux Example Script #131

Open · da-moon opened this issue 2 months ago

da-moon commented 2 months ago

Description of Request

Reason or Need for Feature

Design / Proposal

Build Information

Additional Context
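Below is a standalone script (with MLX-specific references removed) that connects to a single exo peer over gRPC, sends a prompt, and streams the decoded tokens back; it could serve as a starting point for a Linux example in the docs: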

#!/usr/bin/env python3

from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
import asyncio
import argparse
import uuid
from transformers import AutoTokenizer

# Define models without MLX-specific references
models = {
    "meta-llama/Meta-Llama-3.1-70B": Shard(model_id="meta-llama/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
    # "Qwen/Qwen2-72B-Instruct": Shard(model_id="Qwen/Qwen2-72B-Instruct", start_layer=0, end_layer=0, n_layers=80),
    # "meta-llama/Llama-2-7b-chat-hf": Shard(model_id="meta-llama/Llama-2-7b-chat-hf", start_layer=0, end_layer=0, n_layers=32),
}

# Use a generic model path (note: the meta-llama repos on Hugging Face are gated,
# so downloading the tokenizer may require `huggingface-cli login` first)
path_or_hf_repo = "meta-llama/Meta-Llama-3.1-70B"
# path_or_hf_repo = "Qwen/Qwen2-72B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(path_or_hf_repo)
# Define the peer node to talk to. The address must point to a running exo node's
# gRPC endpoint; the DeviceCapabilities values here are placeholders.
# peer1 = GRPCPeerHandle("node1", "localhost:8080", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))

peer2 = GRPCPeerHandle("node2", "localhost:8080", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
shard = models[path_or_hf_repo]
request_id = str(uuid.uuid4())

async def run_prompt(prompt: str):
    messages = [{"role": "user", "content": prompt}]

    # Use the tokenizer's chat template when one is available
    if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None) is not None:
        try:
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fallback to using the raw prompt if the chat template fails
            prompt = messages[0]["content"]
    else:
        # If there's no chat template, use the raw prompt
        prompt = messages[0]["content"]

    # Connect to peers
    await peer2.connect()

    try:
        # Send the prompt to the peer
        await peer2.send_prompt(shard, prompt, request_id)
    except Exception as e:
        print(f"Error sending prompt: {e}")
        return

    previous_length = 0
    n_tokens = 0
    start_time = asyncio.get_running_loop().time()

    while True:
        try:
            # Poll the peer for all tokens generated so far
            result, is_finished = await peer2.get_inference_result(request_id)

            updated_string = tokenizer.decode(result)
            n_tokens = len(result)

            print(updated_string[previous_length:], end='', flush=True)
            previous_length = len(updated_string)

            if is_finished:
                break

        except Exception as e:
            print(f"Error getting inference result: {e}")
        await asyncio.sleep(0.1)

    end_time = asyncio.get_running_loop().time()
    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    parser.add_argument("--prompt", type=str, required=True, help="The prompt to run")
    args = parser.parse_args()
    asyncio.run(run_prompt(args.prompt))
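
A minimal usage sketch, assuming the script is saved as run_prompt.py (a hypothetical filename) and an exo node is already running with its gRPC server reachable at localhost:8080 (the address hard-coded above):

python3 run_prompt.py --prompt "Why did the chicken cross the road?"
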
AlexCheema commented 2 months ago

Thanks for the detailed issue.

Definitely needed. I'm working on quite a big refactor right now that will improve the experience on Linux, and I'll update the docs afterwards. Here's the working branch if you're interested: https://github.com/exo-explore/exo/pull/124