exo-explore / exo

Run your own AI cluster at home with everyday devices 📱💻 🖥️⌚
GNU General Public License v3.0
6.36k stars 329 forks source link

Windows progress #72

Open small-cactus opened 1 month ago

small-cactus commented 1 month ago

I got it to start on Windows and detect other devices; however, the Windows PC itself is not shown on other devices running exo — it is detected as nothing (`[]`, according to debug output). On the Windows PC, it shows itself as an unknown device at 0 TFLOPS.

If anyone has an idea on how to get it to run, that'd be pretty cool.

Updated main.py to allow cross platform with everything:

import argparse
import asyncio
import signal
import uuid
import os
from typing import List
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
# NOTE: the default is evaluated once at parser construction, so each process
# gets a single random node id unless --node-id is passed explicitly.
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=256, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
args = parser.parse_args()

print_yellow_exo()

system_info = get_system_info()
print(f"Detected system: {system_info}")

# MLX is Apple-Silicon-only; every other platform (including Windows) falls
# back to tinygrad unless --inference-engine overrides the choice.
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
inference_engine = get_inference_engine(inference_engine_name)
print(f"Using inference engine: {inference_engine.__class__.__name__}")

# Pick a free port at startup when none was given on the command line.
if args.node_port is None:
    args.node_port = find_available_port(args.node_host)
    if DEBUG >= 1: print(f"Using available port: {args.node_port}")

# Wire discovery -> node -> gRPC server. The node is constructed with
# server=None (second positional argument) and the server is attached
# afterwards (node.server = server) because each needs a reference to the other.
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
# Print decoded tokens as they stream in; falls back to printing raw token
# ids when the engine exposes no tokenizer attribute.
node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
# Metrics are optional: imported lazily so the stats dependency is only
# needed when --prometheus-client-port is actually set.
if args.prometheus_client_port:
    from exo.stats.metrics import start_metrics_server
    start_metrics_server(node, args.prometheus_client_port)

async def shutdown(signal, loop):
    """Gracefully shutdown the server and close the asyncio loop.

    Cancels every outstanding task except this coroutine, waits for the
    cancellations to settle, stops the module-level gRPC ``server``, and
    finally stops ``loop``.

    Args:
        signal: the signal object that triggered shutdown (this parameter
            shadows the stdlib ``signal`` module inside the function; only
            ``.name`` is read from it).
        loop: the event loop to stop once cleanup is done.
    """
    print(f"Received exit signal {signal.name}...")
    print("Thank you for using exo.")
    print_yellow_exo()
    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    # A plain loop, not a throwaway list comprehension: cancel() is called
    # purely for its side effect.
    for task in server_tasks:
        task.cancel()
    print(f"Cancelling {len(server_tasks)} outstanding tasks")
    # return_exceptions=True keeps one failing/cancelled task from aborting
    # the rest of the cleanup.
    await asyncio.gather(*server_tasks, return_exceptions=True)
    await server.stop()  # module-level GRPCServer defined at startup
    loop.stop()

def handle_exit():
    """Schedule a graceful shutdown on the currently running event loop."""
    running_loop = asyncio.get_running_loop()
    asyncio.ensure_future(shutdown(signal.SIGTERM, running_loop))

import logging

# Root logging at DEBUG so the full Windows discovery/startup path is traced.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger used by the setup/shutdown helpers below.
logger = logging.getLogger(__name__)

def setup_signal_handlers(loop):
    """Register SIGINT/SIGTERM handlers appropriate for the current OS.

    POSIX systems get loop-native handlers; everything else (i.e. Windows,
    where ``loop.add_signal_handler`` is unavailable) falls back to the
    stdlib ``signal.signal`` mechanism.
    """
    handled_signals = (signal.SIGINT, signal.SIGTERM)
    if os.name == 'posix':
        logger.debug("Setting up signal handlers for POSIX system")
        for sig in handled_signals:
            loop.add_signal_handler(sig, handle_exit)
    else:
        logger.debug("Setting up signal handlers for non-POSIX system")
        for sig in handled_signals:
            signal.signal(sig, lambda s, f: handle_exit())

async def main():
    """Start the node and the ChatGPT-compatible API, then park forever.

    The function only returns when the surrounding task is cancelled (e.g.
    by the signal handlers) or the loop is stopped externally.
    """
    logger.debug(f"Starting main function on {os.name} system")

    logger.debug("Starting node")
    await node.start(wait_for_peers=args.wait_for_peers)

    logger.debug(f"Starting API server on port {args.chatgpt_api_port}")
    # Hold a reference so the API task is not garbage-collected mid-flight.
    api_server_task = asyncio.create_task(api.run(port=args.chatgpt_api_port))

    logger.debug("Entering main event loop")
    # This Event is never set: waiting on it blocks until cancellation.
    run_forever = asyncio.Event()
    try:
        await run_forever.wait()
    except Exception as e:
        logger.error(f"Error in main event loop: {e}")
    finally:
        logger.debug("Exiting main event loop")

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        logger.info("Starting application")
        setup_signal_handlers(loop)  # Set up signal handlers before running main
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Shutting down...")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
    finally:
        logger.info("Running shutdown procedure")
        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
        loop.close()
        logger.info("Application shutdown complete")

I didn't really look at any other files to see where the issue might be happening so it might be an easy fix.

Other remarks:

Models will not download on windows

other systems show 1 node when windows shows 2 nodes

Inference does not work from the Windows PC, whether or not other nodes are connected. It works fine when inference is run from another, non-Windows connected system (because that system only registers 1 node).

Even though windows will state that it has 2 nodes and is connected, none of the tokens get sent to the windows node when ran from Mac.

TLDR: Nothing works on windows, but it recognizes other systems, it's just that nothing works at all

stephanj commented 1 month ago

Claude Sonnet 3.5 suggestions:

The issues you're encountering are likely related to device capability detection and network discovery. Let's address these problems step by step:

  1. Windows PC not shown on other devices: This could be due to firewall settings or network discovery issues. Here are some steps to troubleshoot:

a) Check Windows Firewall:

b) Network Discovery:

c) Broadcast Messages:

  2. Windows PC showing as unknown device with 0 TFLOPS: This is likely due to the device capability detection not working correctly on Windows. Let's improve the device_capabilities() function for Windows:
import platform
import psutil
import wmi

def device_capabilities():
    """Return a DeviceCapabilities record describing the local machine.

    On Windows, queries WMI for CPU and GPU identity and derives a rough
    FLOPS figure from clock speed and core count; other platforms fall
    through to the pre-existing implementation (elided here).
    """
    if platform.system() == "Windows":
        c = wmi.WMI()
        # Index 0 only — assumes a single-socket, single-GPU machine;
        # additional processors/GPUs are ignored.
        cpu_info = c.Win32_Processor()[0]
        gpu_info = c.Win32_VideoController()[0]

        model = f"Windows PC ({platform.processor()})"
        chip = f"{cpu_info.Name}, GPU: {gpu_info.Name}"
        memory = psutil.virtual_memory().total // (1024**2)  # Convert to MB

        # Estimate FLOPS (this is a very rough estimate)
        cpu_ghz = float(cpu_info.MaxClockSpeed) / 1000  # Convert MHz to GHz
        cpu_cores = int(cpu_info.NumberOfCores)
        estimated_gflops = cpu_ghz * cpu_cores * 8  # Assume 8 FLOPS per cycle per core
        # NOTE(review): the value above is in GFLOPS — confirm DeviceFlops
        # expects that unit (the issue mentions TFLOPS elsewhere) and that
        # `memory` is expected in MB, before relying on this.

        return DeviceCapabilities(
            model=model,
            chip=chip,
            memory=memory,
            flops=DeviceFlops(fp32=estimated_gflops, fp16=estimated_gflops*2, int8=estimated_gflops*4)
        )
    else:
        # Existing implementation for other platforms
        ...

This implementation provides a rough estimate of FLOPS based on CPU information. For more accurate GPU FLOPS, you might need to use a GPU-specific library or maintain a database of known GPU performances.

  3. Cross-platform compatibility: To ensure better cross-platform compatibility, consider these modifications:

a) Use os.path.join() for all file paths. b) Use platform.system() to check the operating system when making system-specific calls. c) For network operations, consider using socket.getaddrinfo() to get IP addresses, as it works across platforms.

Here's an example of how you might modify the discovery mechanism to work better on Windows:

import socket
import struct

class GRPCDiscovery(Discovery):
    async def task_broadcast_presence(self):
        """Broadcast this node's presence over UDP so peers can discover it.

        Sends one JSON discovery datagram per round (per interface on
        Windows), then sleeps for ``self.broadcast_interval`` seconds.
        """
        # Context-manage the socket so it is always closed — the original
        # version created a new socket each round and never closed it.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

            message = json.dumps({
                "type": "discovery",
                "node_id": self.node_id,
                "grpc_port": self.node_port,
                "device_capabilities": self.device_capabilities.to_dict()
            }).encode('utf-8')

            if platform.system() == "Windows":
                # On Windows, we might need to broadcast to all network interfaces.
                # Restrict to IPv4 datagram entries: the x.y.z.255 derivation
                # below is IPv4-only, and an IPv6 address would produce a
                # malformed broadcast target (e.g. "fe80::1.255").
                for interface in socket.getaddrinfo(socket.gethostname(), None):
                    if interface[0] == socket.AF_INET and interface[1] == socket.SOCK_DGRAM:
                        broadcast_address = interface[4][0].rsplit('.', 1)[0] + '.255'
                        sock.sendto(message, (broadcast_address, self.broadcast_port))
            else:
                # Existing broadcast logic for other platforms
                sock.sendto(message, ('<broadcast>', self.broadcast_port))

        await asyncio.sleep(self.broadcast_interval)

This modification attempts to broadcast on all available network interfaces on Windows, which might help with discovery issues.

Lastly, ensure that all necessary libraries (like wmi for Windows) are installed on your Windows machine.

These changes should help improve the Windows compatibility of your exo project. Remember to test thoroughly on both Windows and other platforms to ensure these modifications don't introduce new issues. If you're still encountering problems, more detailed logs of the discovery process and network communications would be helpful for further debugging.