fabriziosalmi / proxmox-vm-autoscale

Automatically scale virtual machine resources on Proxmox hosts
MIT License
153 stars 6 forks source link

ssh_utils.py improvement #11

Open fabriziosalmi opened 4 weeks ago

fabriziosalmi commented 4 weeks ago
import paramiko
import logging
import time
import select

class SSHClient:
    """
    Small convenience wrapper around ``paramiko.SSHClient``.

    The connection is opened lazily by :meth:`execute_command` (or explicitly
    via :meth:`connect` / the context-manager protocol) and released with
    :meth:`close`.
    """

    # Seconds allowed for establishing the TCP/SSH connection.
    _CONNECT_TIMEOUT = 10
    # Seconds to sleep between polls while waiting for a command to finish.
    _POLL_INTERVAL = 0.05
    # Maximum number of bytes pulled from the channel per recv call.
    _CHUNK_SIZE = 32768

    def __init__(self, host, user, password=None, key_path=None):
        """
        Initializes the SSH client with given credentials.
        :param host: Hostname or IP address of the server.
        :param user: Username to connect with.
        :param password: Password for SSH (optional).
        :param key_path: Path to the private SSH key (optional).
        """
        self.host = host
        self.user = user
        self.password = password
        self.key_path = key_path
        self.logger = logging.getLogger("ssh_utils")
        self.client = None

    def connect(self):
        """
        Establish an SSH connection to the host.

        Does nothing when a connection object is already held. The password
        takes precedence over the key when both are configured.
        :raises ValueError: If neither password nor key_path is set.
        :raises Exception: Whatever paramiko/the socket layer raises when the
            connection attempt fails (logged, then re-raised).
        """
        if self.client is not None:
            return

        # Validate credentials before touching the network.
        if self.password:
            auth_kwargs = {"password": self.password}
        elif self.key_path:
            # key_filename lets paramiko detect the key type itself instead
            # of assuming the file holds an RSA key.
            auth_kwargs = {"key_filename": self.key_path}
        else:
            error = ValueError("Either password or key_path must be provided for SSH connection.")
            self.logger.error("Failed to connect to %s: %s", self.host, error)
            raise error

        client = paramiko.SSHClient()
        # SECURITY NOTE(review): AutoAddPolicy trusts any unknown host key,
        # which leaves the first connection open to man-in-the-middle
        # attacks. Kept for backward compatibility; consider loading known
        # host keys and using RejectPolicy instead.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        try:
            client.connect(
                self.host,
                username=self.user,
                timeout=self._CONNECT_TIMEOUT,
                **auth_kwargs
            )
        except Exception as e:
            self.logger.error("Failed to connect to %s: %s", self.host, e)
            # Do not keep a half-initialised client around: a later
            # connect() call must be able to try again.
            client.close()
            raise

        # Only publish the client once it is actually connected.
        self.client = client
        self.logger.info("Successfully connected to %s", self.host)

    def _drain(self, channel, out_chunks, err_chunks):
        """Move all currently buffered stdout/stderr bytes into the lists."""
        while channel.recv_ready():
            out_chunks.append(channel.recv(self._CHUNK_SIZE))
        while channel.recv_stderr_ready():
            err_chunks.append(channel.recv_stderr(self._CHUNK_SIZE))

    def _run_once(self, command, timeout):
        """
        Run the command a single time on the current connection.
        :return: Tuple ``(exit_status, stdout_text, stderr_text)``.
        :raises TimeoutError: If the command has not exited within timeout.
        """
        _stdin, stdout, _stderr = self.client.exec_command(command, timeout=timeout)
        channel = stdout.channel

        out_chunks = []
        err_chunks = []
        deadline = time.monotonic() + timeout

        # Keep draining output while waiting for the exit status so that a
        # chatty command cannot stall on a full channel window.
        while not channel.exit_status_ready():
            self._drain(channel, out_chunks, err_chunks)
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                channel.close()
                self.logger.error("Command timed out on %s: %s", self.host, command)
                raise TimeoutError(f"Command '{command}' timed out after {timeout} seconds.")
            time.sleep(min(self._POLL_INTERVAL, remaining))

        # Collect whatever was buffered before the exit status arrived.
        self._drain(channel, out_chunks, err_chunks)
        exit_status = channel.recv_exit_status()

        output = b"".join(out_chunks).decode("utf-8", errors="replace").strip()
        error_message = b"".join(err_chunks).decode("utf-8", errors="replace").strip()
        return exit_status, output, error_message

    def execute_command(self, command, timeout=15):
        """
        Execute a command on the remote server with a timeout.

        If the connection itself fails (SSH/socket error) the client
        reconnects once and runs the command one more time; command failures
        and timeouts are reported to the caller and are not retried here.
        :param command: Command to execute.
        :param timeout: Timeout in seconds for the whole command.
        :return: Output of the command (stdout, stripped).
        :raises RuntimeError: If the command exits with a non-zero status.
        :raises TimeoutError: If the command does not finish within timeout.
        """
        if self.client is None:
            self.connect()

        try:
            exit_status, output, error_message = self._run_once(command, timeout)
        except TimeoutError:
            # TimeoutError is an OSError subclass; let it through untouched
            # instead of treating it as a broken connection.
            raise
        except (paramiko.SSHException, EOFError, OSError) as e:
            self.logger.error("Error executing command on %s: %s", self.host, e)
            # Connection-level failure: reconnect and try exactly once more.
            self.close()
            self.connect()
            exit_status, output, error_message = self._run_once(command, timeout)

        if exit_status != 0:
            self.logger.error(
                "Command failed on %s: %s\nError: %s", self.host, command, error_message
            )
            raise RuntimeError(f"Command execution failed: {error_message}")

        self.logger.info("Command executed successfully on %s: %s", self.host, command)
        return output

    def execute_command_with_retry(self, command, timeout=15, retries=3, delay=5):
        """
        Execute a command on the remote server with retries and a timeout.
        :param command: Command to execute.
        :param timeout: Timeout in seconds.
        :param retries: Total number of attempts.
        :param delay: Delay in seconds between attempts.
        :return: Output of the command.
        :raises RuntimeError: If every attempt failed; the last underlying
            error is attached as ``__cause__``.
        """
        last_error = None
        for attempt in range(1, retries + 1):
            try:
                return self.execute_command(command, timeout)
            except (TimeoutError, RuntimeError) as e:
                last_error = e
                # Drop the connection; execute_command() reconnects lazily on
                # the next attempt.
                self.close()
                if attempt < retries:
                    self.logger.warning(
                        "Attempt %d failed for command '%s'. Retrying in %s seconds.",
                        attempt, command, delay,
                    )
                    time.sleep(delay)
                else:
                    self.logger.warning(
                        "Attempt %d failed for command '%s'. No retries left.",
                        attempt, command,
                    )
        raise RuntimeError(f"Command '{command}' failed after {retries} attempts.") from last_error

    def close(self):
        """
        Close the SSH connection.

        Safe to call repeatedly or when no connection was ever opened.
        """
        client = self.client
        if client is None:
            return
        # Reset first so a failing close() cannot leave a stale client behind.
        self.client = None
        client.close()
        self.logger.info("SSH connection closed for %s", self.host)

    def __enter__(self):
        """
        Context manager entry: connect and return the client wrapper.
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Context manager exit - ensure the SSH connection is closed.
        """
        self.close()