ultravioletrs / cocos

Cocos AI - Confidential Computing System for AI
https://ultraviolet.rs/cocos.html
Apache License 2.0

NOISSUE - V-Sock reconnect for agent #215

Closed SammyOina closed 1 month ago

SammyOina commented 1 month ago

What type of PR is this?

This is a feature

What does this do?

Adds vsock reconnect handling so the agent can re-establish its connection to the manager after a disconnect.

Which issue(s) does this PR fix/relate to?

None

Have you included tests for your changes?

No

Did you document any new/modified feature?

No

Notes

To test, use this modified version of test/agent-config:

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0

// Simplified script that passes configs to the agent without a manager,
// and reads the logs and events destined for the manager.
// This tool is meant for testing purposes.
package main

import (
    "encoding/json"
    "encoding/pem"
    "fmt"
    "log"
    "net"
    "os"
    "strconv"
    "time"

    "github.com/mdlayher/vsock"
    "github.com/ultravioletrs/cocos/agent"
    "github.com/ultravioletrs/cocos/internal"
    "github.com/ultravioletrs/cocos/manager"
    "github.com/ultravioletrs/cocos/manager/qemu"
    pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
    "golang.org/x/exp/rand"
    "google.golang.org/protobuf/proto"
)

func main() {
    if len(os.Args) < 5 {
        log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>", os.Args[0])
    }
    dataPath := os.Args[1]
    algoPath := os.Args[2]
    pubKeyFile := os.Args[3]
    attestedTLS, err := strconv.ParseBool(os.Args[4])
    if err != nil {
        log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>, <attested-tls-bool> must be a bool value", os.Args[0])
    }

    pubKey, err := os.ReadFile(pubKeyFile)
    if err != nil {
        log.Fatalf("failed to read public key file: %s", err)
    }
    pubPem, _ := pem.Decode(pubKey)
    if pubPem == nil {
        log.Fatal("failed to decode PEM block from public key file")
    }
    algoHash, err := internal.Checksum(algoPath)
    if err != nil {
        log.Fatalf("failed to calculate checksum: %s", err)
    }
    dataHash, err := internal.Checksum(dataPath)
    if err != nil {
        log.Fatalf("failed to calculate checksum: %s", err)
    }

    ac := agent.Computation{
        ID:              "123",
        Datasets:        agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
        Algorithm:       agent.Algorithm{Hash: [32]byte(algoHash), UserKey: pubPem.Bytes},
        ResultConsumers: []agent.ResultConsumer{{UserKey: pubPem.Bytes}},
        AgentConfig: agent.AgentConfig{
            LogLevel:    "debug",
            Port:        "7002",
            AttestedTls: attestedTLS,
        },
    }
    // Send the agent config to the agent guest over vsock (CID 3)
    if err := SendAgentConfig(3, ac); err != nil {
        log.Printf("Error sending agent config: %v", err)
    }
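    // Repeatedly bring up a listener on the manager vsock port and tear it
    // down after a random interval, simulating manager-side disconnects so
    // the agent's reconnect behavior can be exercised.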
    for {
        l, err := vsock.Listen(manager.ManagerVsockPort, nil)
        if err != nil {
            log.Printf("Error creating listener: %v", err)
            time.Sleep(5 * time.Second) // Wait for 5 seconds before retrying
            continue
        }

        log.Println("Listener started")

        // Start a goroutine to handle connections
        done := make(chan struct{})
        go func() {
            for {
                conn, err := l.Accept()
                if err != nil {
                    log.Printf("Error accepting connection: %v", err)
                    close(done)
                    return
                }
                go handleConnections(conn)
            }
        }()

        // Simulate disconnection after a random interval between 10 and 30 seconds
        disconnectAfter := time.Duration(10+rand.Intn(21)) * time.Second
        select {
        case <-time.After(disconnectAfter):
            log.Println("Simulating disconnection")
            l.Close()
        case <-done:
            // The listener has already closed due to an error
        }

        // Wait for 5 seconds before reconnecting
        log.Println("Waiting 5 seconds before reconnecting")
        time.Sleep(5 * time.Second)
    }
}

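// SendAgentConfig dials the agent's config vsock port on the given CID and
// sends the computation manifest as JSON.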
func SendAgentConfig(cid uint32, ac agent.Computation) error {
    conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil)
    if err != nil {
        return err
    }
    defer conn.Close()
    payload, err := json.Marshal(ac)
    if err != nil {
        return err
    }

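    // Sanity-check that the payload round-trips through JSON before sending it.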
    var ac2 agent.Computation
    if err := json.Unmarshal(payload, &ac2); err != nil {
        return err
    }
    if _, err := conn.Write(payload); err != nil {
        return err
    }
    return nil
}

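// handleConnections reads protobuf-encoded ClientStreamMessages from the
// connection and prints them; each Read is assumed to carry one message.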
func handleConnections(conn net.Conn) {
    defer conn.Close()
    for {
        b := make([]byte, 1024)
        n, err := conn.Read(b)
        if err != nil {
            log.Println(err)
            return
        }
        var message pkgmanager.ClientStreamMessage
        if err := proto.Unmarshal(b[:n], &message); err != nil {
            log.Println(err)
            return
        }
        fmt.Println(message.String())
    }
}
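
For reference, the script is invoked with the arguments from its usage message, e.g. (the file paths here are placeholders):

go run ./test/agent-config ./data.zip ./algo.py ./public.pem false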
drasko commented 1 month ago

We also need to buffer logs while the Manager vsock connection is absent, and send them on reconnect
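
A minimal sketch of that buffering approach (type and method names here are illustrative, not taken from this PR; the actual implementation landed in #222):

// Package logbuffer is an illustrative sketch of buffering messages while
// the vsock connection is down and replaying them on reconnect. Names here
// are hypothetical; see #222 for the real implementation.
package logbuffer

import (
    "net"
    "sync"
)

// BufferedSender queues outgoing messages while disconnected and flushes
// the backlog once a connection is re-established.
type BufferedSender struct {
    mu   sync.Mutex
    buf  [][]byte
    conn net.Conn
}

// Send writes msg on the current connection, or buffers it when disconnected.
func (s *BufferedSender) Send(msg []byte) {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.conn == nil {
        s.buf = append(s.buf, msg)
        return
    }
    if _, err := s.conn.Write(msg); err != nil {
        // Treat a failed write as a disconnect and keep the message.
        s.conn = nil
        s.buf = append(s.buf, msg)
    }
}

// Reconnect installs a fresh connection and replays buffered messages in order.
func (s *BufferedSender) Reconnect(conn net.Conn) {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.conn = conn
    for i, msg := range s.buf {
        if _, err := conn.Write(msg); err != nil {
            s.conn = nil
            s.buf = s.buf[i:] // keep the unsent tail buffered
            return
        }
    }
    s.buf = nil
}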

SammyOina commented 1 month ago

We also need to buffer logs while the Manager vsock connection is absent, and send them on reconnect

This is done in #222.