aleph-im / aleph-sdk-python

Python SDK library for the Aleph.im network
MIT License
2 stars 4 forks source link

Implement `VmConfidentialClient` class #138

Closed nesitor closed 2 weeks ago

nesitor commented 2 weeks ago

Problem: A user cannot initialize an already created confidential VM.

Solution: Implement VmConfidentialClient class to be able to initialize and interact with confidential VMs.

github-actions[bot] commented 2 weeks ago

Summary: The PR modifies a significant number of files across multiple directories. It includes a refactoring of the 'vmclient.py' file into 'vm_client.py', adds a new 'vm_confidential_client.py' file, and renames several other files. This indicates a high level of complexity and potential impact on the codebase.

Highlight: The diff shows the renaming of 'vmclient.py' to 'vm_client.py' and the addition of a new 'vm_confidential_client.py' file. This suggests a significant refactoring and new feature implementation.

diff  --git a/pyproject.toml b/pyproject.toml
index b52efe66..1070a7f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7  @@ dependencies = [
      "python-magic",
      "typer",
      "typing_extensions",
+     "aioresponses>=0.7.6"
  ]

  [project.optional-dependencies]
diff  --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vm_client.py
similarity index 100%
rename from src/aleph/sdk/client/vmclient.py
rename to src/aleph/sdk/client/vm_client.py
diff  --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py
new file mode 100644
index 00000000..305ff8ef
--- /dev/null
+++ b/src/aleph/sdk/client/vm_confidential_client.py
@@ -0,0 +1,155 @@
+import json
+import logging
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import aiohttp
+from aleph_message.models import ItemHash
+
+from aleph.sdk.client.vm_client import VmClient
+from aleph.sdk.types import Account
+from aleph.sdk.utils import run_in_subprocess
+
+# Module-level logger; used below to report HTTP errors from node requests.
+logger = logging.getLogger(__name__)
+
+
+class VmConfidentialClient(VmClient):
+    sevctl_path: Path
+
+    def __init__(
+        self,
+        account: Account,
+        sevctl_path: Path,
+        node_url: str = "",
+        session: Optional[aiohttp.ClientSession] = None,
+    ):
+        super().__init__(account, node_url, session)
+        self.sevctl_path = sevctl_path
+
+    async def get_certificates(self) -> Tuple[Optional[int], str]:
+        url = f"{self.node_url}/about/certificates"
+        try:
+            async with self.session.get(url) as response:
+                data = await response.read()
+                with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+                    tmp_file.write(data)
+                    return response.status, tmp_file.name
+
+        except aiohttp.ClientError as e:
+            logger.error(
+                f"HTTP error getting node certificates on {self.node_url}: {str(e)}"
+            )
+            return None, str(e)
+
+    async def create_session(
+        self, vm_id: ItemHash, certificate_path: Path, policy: int
+    ):
+        args = [
+            "session",
+            "--name",
+            vm_id,
+            str(certificate_path),
+            str(policy),
+        ]
+        try:
+            await self.sevctl_cmd(args)
+        except Exception as e:
+            raise ValueError(f"Session creation have failed, reason: {str(e)}")
+
+    async def initialize(
+        self, vm_id: ItemHash, session: Path, godh: Path
+    ) -> Tuple[Optional[int], str]:
+        session_file = session.read_bytes()
+        godh_file = godh.read_bytes()
+        params = {
+            "session": session_file,
+            "godh": godh_file,
+        }
+        return await self.perform_confidential_operation(
+            vm_id, "confidential/initialize", params=params
+        )
+
+    async def measurement(self, vm_id: ItemHash) -> Tuple[Optional[int], str]:
+        status, text = await self.perform_confidential_operation(
+            vm_id, "confidential/measurement"
+        )
+        if status:
+            response = json.loads(text)
+            return status, response
+
+        return status, text
+
+    async def validate_measurement(self, vm_id: ItemHash) -> bool:
+        return True
+
+    async def build_secret(
+        self, tek_path: Path, tik_path: Path, measurement: str, secret: str
+    ) -> Tuple[Path, Path]:
+        current_path = Path().cwd()
+        secret_header_path = current_path / "secret_header.bin"
+        secret_payload_path = current_path / "secret_payload.bin"
+        args = [
+            "secret",
+            "build",
+            "--tik",
+            str(tik_path),
+            "--tek",
+            str(tek_path),
+            "--launch-measure-blob",
+            measurement,
+            "--secret",
+            secret,
+            str(secret_header_path),
+            str(secret_payload_path),
+        ]
+        try:
+            await self.sevctl_cmd(args)
+            return secret_header_path, secret_payload_path
+        except Exception as e:
+            raise ValueError(f"Secret building have failed, reason: {str(e)}")
+
+    async def inject_secret(
+        self, vm_id: ItemHash, packed_header: str, secret: str
+    ) -> Tuple[Optional[int], str]:
+        params = {
+            "packed_header": packed_header,
+            "secret": secret,
+        }
+        status, text = await self.perform_confidential_operation(
+            vm_id, "confidential/inject_secret", params=params
+        )
+
+        if status:
+            response = json.loads(text)
+            return status, response
+
+        return status, text
+
+    async def perform_confidential_operation(
+        self, vm_id: ItemHash, operation: str, params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[Optional[int], str]:
+        if not self.pubkey_signature_header:
+            self.pubkey_signature_header = await self._generate_pubkey_signature_header()
+
+        url, header = await self._generate_header(vm_id=vm_id, operation=operation)
+
+        try:
+            async with self.session.post(url, headers=header, data=params) as response:
+                response_text = await response.text()
+                return response.status, response_text
+
+        except aiohttp.ClientError as e:
+            logger.error(f"HTTP error during operation {operation}: {str(e)}")
+            return None, str(e)
+
+    async def sevctl_cmd(self, *args) -> bytes:
+        return await run_in_subprocess(
+            ["sevctl", *args],
+            check=True,
+        )
diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py
index 2d1b30c7..130edc38 100644
--- a/src/aleph/sdk/utils.py
+++ b/src/aleph/sdk/utils.py
@@ -1,8 +1,10 @@
 import errno
 import hashlib
 import json
 import logging
 import os
+import subprocess
 from datetime import date, datetime, time
 from enum import Enum
 from pathlib import Path
@@ -11,6 +13,7 @@ def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str:
          }
      )
     return signed_operation
+
+
+async def run_in_subprocess(
+    command: List[str], check: bool = True, stdin_input: Optional[bytes] = None
+) -> bytes:
+    logger.debug(f"command: {'  '.join(command)}")
+
+    process = await asyncio.create_subprocess_exec(
+        *command,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+    )
+    stdout, stderr = await process.communicate(input=stdin_input)
+
+    if check and process.returncode:
+        logger.error(
+            f"Command failed with error code {process.returncode}:\n"
+            f"    stdin = {stdin_input}\n"
+            f"    command = {command}\n"
+            f"    stdout = {stderr}"
+        )
+        raise subprocess.CalledProcessError(
+            process.returncode, str(command), stderr.decode()
+        )
+
+    return stdout

The PR introduces significant changes to the codebase that require deep understanding of the project architecture. As such, a 'BLACK' rating is recommended.