exo-explore / exo

Run your own AI cluster at home with everyday devices 📱💻 🖥️⌚

No way to cancel #61

magnusviri opened this issue 1 month ago (status: Open)

magnusviri commented 1 month ago

There's no way to stop an inference. Reloading or closing the webpage doesn't stop it either. The nodes just keep going.

AlexCheema commented 1 month ago

What would be the desired behaviour? How would you like to stop the inference?

stephanj commented 1 month ago

Implementing cancellation would involve a few key changes across different parts of the codebase. Here's an approach to implementing inference stopping:

  1. Add a cancellation flag: First, we need a way to signal that an inference should be cancelled. We can add a cancellation flag to the Node class.

In exo/orchestration/node.py, update the Node abstract base class:

from abc import ABC, abstractmethod
from asyncio import Event
from typing import Dict

class Node(ABC):
    # ... existing code ...

    @abstractmethod
    async def cancel_inference(self, request_id: str) -> None:
        pass

    @property
    @abstractmethod
    def cancellation_events(self) -> Dict[str, Event]:
        pass
  2. Implement cancellation in StandardNode: In exo/orchestration/standard_node.py, implement the new method and property:
class StandardNode(Node):
    def __init__(self, ...):
        # ... existing initialization ...
        self._cancellation_events: Dict[str, Event] = {}

    async def cancel_inference(self, request_id: str) -> None:
        if request_id in self._cancellation_events:
            self._cancellation_events[request_id].set()
            # Propagate cancellation to other nodes
            await self.broadcast_cancellation(request_id)

    @property
    def cancellation_events(self) -> Dict[str, Event]:
        return self._cancellation_events

    async def broadcast_cancellation(self, request_id: str) -> None:
        async def send_cancellation_to_peer(peer):
            try:
                await asyncio.wait_for(peer.cancel_inference(request_id), timeout=15.0)
            except asyncio.TimeoutError:
                print(f"Timeout sending cancellation to {peer.id()}")
            except Exception as e:
                print(f"Error sending cancellation to {peer.id()}: {e}")

        await asyncio.gather(*[send_cancellation_to_peer(peer) for peer in self.peers], return_exceptions=True)

    # Update process_prompt and process_tensor methods to check for cancellation
    async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
        if request_id is None:
            request_id = str(uuid.uuid4())
        self._cancellation_events[request_id] = Event()

        try:
            # ... existing code ...

            while not is_finished:
                if self._cancellation_events[request_id].is_set():
                    print(f"Inference cancelled for request {request_id}")
                    return None

                # ... rest of the processing loop ...

        finally:
            # Always clean up the event, even if processing raised
            self._cancellation_events.pop(request_id, None)

    # Similar changes for _process_tensor
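For completeness, the corresponding check in _process_tensor would look like this; a minimal sketch, assuming its signature mirrors _process_prompt and eliding the existing logic just as above:

    async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
        if request_id is None:
            request_id = str(uuid.uuid4())
        self._cancellation_events[request_id] = Event()

        try:
            while not is_finished:
                # Same cooperative check as in _process_prompt
                if self._cancellation_events[request_id].is_set():
                    print(f"Inference cancelled for request {request_id}")
                    return None

                # ... rest of the tensor processing loop ...
        finally:
            self._cancellation_events.pop(request_id, None)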
  3. Update PeerHandle and GRPCPeerHandle: In exo/networking/peer_handle.py, add the new method:
class PeerHandle(ABC):
    # ... existing methods ...

    @abstractmethod
    async def cancel_inference(self, request_id: str) -> None:
        pass

In exo/networking/grpc/grpc_peer_handle.py, implement the new method:

class GRPCPeerHandle(PeerHandle):
    # ... existing methods ...

    async def cancel_inference(self, request_id: str) -> None:
        request = node_service_pb2.CancelInferenceRequest(request_id=request_id)
        await self.stub.CancelInference(request)
  4. Update gRPC service: In exo/networking/grpc/node_service.proto, add a new RPC method:
service NodeService {
  // ... existing methods ...
  rpc CancelInference (CancelInferenceRequest) returns (Empty) {}
}

message CancelInferenceRequest {
  string request_id = 1;
}

Then regenerate the gRPC code.
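For example, with the grpcio-tools package installed, regeneration is typically a one-liner (the include and output paths here are assumptions about the repo layout):

python -m grpc_tools.protoc -I exo/networking/grpc --python_out=exo/networking/grpc --grpc_python_out=exo/networking/grpc exo/networking/grpc/node_service.proto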

In exo/networking/grpc/grpc_server.py, implement the new method:

class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    # ... existing methods ...

    async def CancelInference(self, request, context):
        request_id = request.request_id
        await self.node.cancel_inference(request_id)
        return node_service_pb2.Empty()
  5. Update API endpoints: In exo/api/chatgpt_api.py, add a new endpoint to cancel inference:
class ChatGPTAPI:
    # ... existing methods ...

    async def handle_cancel_inference(self, request):
        data = await request.json()
        request_id = data.get('request_id')
        if not request_id:
            return web.json_response({'error': 'Missing request_id'}, status=400)

        await self.node.cancel_inference(request_id)
        return web.json_response({'status': 'Cancellation request sent'})

    async def run(self, host: str = "0.0.0.0", port: int = 8000):
        # ... existing setup ...
        self.app.router.add_post('/v1/chat/cancel', self.handle_cancel_inference)
        # ... rest of the method ...

With these changes in place, cancellation works end to end: when a user wants to stop an inference, they send a POST request to /v1/chat/cancel with the request_id, and the cancellation propagates through the network of nodes, stopping the inference process.

To handle this on the frontend, you would need to add a cancel button or mechanism that sends this cancellation request when activated. The exact implementation would depend on your frontend setup.

Remember to handle the cancellation gracefully in your inference engines and clean up any resources that were being used for the cancelled inference.

This implementation provides a basic framework for cancellation. Depending on your specific needs, you might want to add more sophisticated error handling, cleanup procedures, or status reporting for cancelled inferences.
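As one illustration of what "graceful handling" could look like inside an engine, here is a hypothetical token-generation loop that checks the shared event between decoding steps. The engine class, its helper methods, and the way events are passed in are all assumptions for illustration, not exo's actual inference API:

from asyncio import Event, sleep
from typing import Dict, List

MAX_TOKENS = 1024  # illustrative generation cap

class DummyInferenceEngine:
    """Hypothetical engine illustrating cooperative cancellation."""

    async def infer_prompt(self, request_id: str, prompt: str,
                           cancellation_events: Dict[str, Event]) -> List[int]:
        output_tokens: List[int] = []
        for _ in range(MAX_TOKENS):
            # Check between token steps so a cancel takes effect within one
            # decoding iteration rather than after the full response.
            event = cancellation_events.get(request_id)
            if event is not None and event.is_set():
                break
            output_tokens.append(await self._generate_next_token(prompt, output_tokens))
        self._release_request_state(request_id)  # e.g. free a KV cache
        return output_tokens

    async def _generate_next_token(self, prompt: str, so_far: List[int]) -> int:
        await sleep(0)  # stand-in for the real decoding step
        return len(so_far)

    def _release_request_state(self, request_id: str) -> None:
        pass  # stand-in for real per-request resource cleanup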


Let's consider how a user might actually initiate the cancellation of an inference:

  1. Web Interface (TinyChat):

The most user-friendly way to stop an inference would be through the web interface. We can add a "Stop" or "Cancel" button to the chat interface.

In the tinychat/examples/tinychat/index.html file, we could add a button next to the "Send" button:

<button id="sendButton">Send</button>
<button id="stopButton" style="display: none;">Stop</button>

Then in the JavaScript, we'd show the stop button when an inference starts and hide it when it completes:

let currentRequestId = null;

async function sendMessage() {
    // ... existing code ...
    // NOTE: the server must see this same ID, so either send it with the
    // completions request or replace it with the ID the server returns.
    currentRequestId = Date.now().toString(); // Simple unique ID
    document.getElementById('stopButton').style.display = 'inline';
    // ... rest of the function ...
}

async function stopInference() {
    if (currentRequestId) {
        const response = await fetch('/v1/chat/cancel', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({ request_id: currentRequestId }),
        });
        if (response.ok) {
            console.log('Inference cancelled');
        } else {
            console.error('Failed to cancel inference');
        }
    }
}

document.getElementById('stopButton').addEventListener('click', stopInference);

// Hide stop button when inference completes
function appendAssistantMessage(message) {
    // ... existing code ...
    document.getElementById('stopButton').style.display = 'none';
    currentRequestId = null; // clear the ID so a stale cancel cannot be sent
}
  2. ChatGPT API:

For users interacting via the ChatGPT API, we should document the new cancellation endpoint. They would need to keep track of the request_id returned when starting an inference, then use it to cancel:

import requests

# Start inference
response = requests.post('http://localhost:8000/v1/chat/completions', json={...})
request_id = response.json()['id']

# Cancel inference
cancel_response = requests.post('http://localhost:8000/v1/chat/cancel', json={'request_id': request_id})
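Note that response.json()['id'] only works for a non-streaming completion. With streaming enabled, the client would instead read the id from the first streamed chunk; a rough sketch, assuming an OpenAI-style SSE stream where each chunk carries the request's id field:

import json
import requests

with requests.post('http://localhost:8000/v1/chat/completions',
                   json={'stream': True,
                         'messages': [{'role': 'user', 'content': 'Hello'}]},
                   stream=True) as response:
    request_id = None
    for line in response.iter_lines():
        if not line or not line.startswith(b'data: '):
            continue
        payload = line[len(b'data: '):]
        if payload == b'[DONE]':
            break
        chunk = json.loads(payload)
        if request_id is None:
            request_id = chunk.get('id')  # keep this for /v1/chat/cancel
        # ... render chunk['choices'][0]['delta'] as it arrives ...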
  3. Command Line Interface:

If you have a CLI for interacting with Exo, you could add a cancel command:

exo cancel <request_id>

This would send the cancellation request to the API, as sketched below.
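A minimal sketch of that subcommand as a standalone script (argparse-based; the --host flag and default port are assumptions, and this is not wired into exo's real entry point):

import argparse
import requests

def main() -> None:
    parser = argparse.ArgumentParser(prog='exo')
    subparsers = parser.add_subparsers(dest='command', required=True)
    cancel = subparsers.add_parser('cancel', help='Cancel a running inference')
    cancel.add_argument('request_id', help='ID returned when the inference started')
    cancel.add_argument('--host', default='http://localhost:8000', help='API base URL')
    args = parser.parse_args()

    if args.command == 'cancel':
        response = requests.post(f'{args.host}/v1/chat/cancel',
                                 json={'request_id': args.request_id})
        print(response.json())

if __name__ == '__main__':
    main()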

  4. Automatic Cancellation:

In addition to manual cancellation, we should implement automatic cancellation when a client disconnects:

In exo/api/chatgpt_api.py, modify the streaming response:

async def handle_post_chat_completions(self, request):
    # ... existing code ...

    if stream:
        response = web.StreamResponse(
            status=200,
            reason='OK',
            headers={'Content-Type': 'application/json'},
        )
        await response.prepare(request)

        try:
            # ... streaming code ...
        except ConnectionResetError:
            print(f"Client disconnected. Cancelling inference {request_id}")
            await self.node.cancel_inference(request_id)
        finally:
            await response.write_eof()

    # ... rest of the method ...

This ensures that if a user closes their browser or terminates their connection, the inference is automatically cancelled.

  5. Timeout-based Cancellation:

We could also implement a timeout-based cancellation for inferences that take too long:

async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None, timeout: float = 300) -> Optional[np.ndarray]:
    if request_id is None:
        request_id = str(uuid.uuid4())

    try:
        return await asyncio.wait_for(self._process_prompt(base_shard, prompt, request_id, inference_state), timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Inference {request_id} timed out after {timeout} seconds. Cancelling.")
        await self.cancel_inference(request_id)
        return None

These approaches provide multiple ways for users to stop an inference, either manually through the interface, programmatically through the API, or automatically in case of disconnection or timeout. The specific implementation would depend on your user interface and use cases, but this gives a comprehensive set of options to ensure users can stop inferences when needed.

magnusviri commented 1 month ago

> What would be the desired behaviour? How would you like to stop the inference?

Just a button on the web interface.

AlexCheema commented 1 month ago

> What would be the desired behaviour? How would you like to stop the inference?
>
> Just a button on the web interface.

Okay, I will make this a Quality of Life upgrade probably together with https://github.com/exo-explore/exo/issues/67