threefoldtecharchive / jumpscaleX_core

Apache License 2.0
1 stars 6 forks source link

better resource mgmt with gedis #637

Closed xmonader closed 4 years ago

xmonader commented 4 years ago

Posting the previous work of @zaibon and @grimpy for reference, together with the problem spots they found:

diff --git a/JumpscaleCore/servers/gedis/handlers.py b/JumpscaleCore/servers/gedis/handlers.py
index 6a88d11e..2c44597a 100644
--- a/JumpscaleCore/servers/gedis/handlers.py
+++ b/JumpscaleCore/servers/gedis/handlers.py
@@ -1,14 +1,16 @@
-from Jumpscale import j
-from redis.exceptions import ConnectionError
+import binascii
+
 import nacl
-from .protocol import RedisCommandParser, RedisResponseWriter
 from nacl.signing import VerifyKey
-import binascii
+from redis.exceptions import ConnectionError

-JSBASE = j.baseclasses.object
+from Jumpscale import j

+from .protocol_new import Disconnect, Error, ProtocolHandler
 from .UserSession import UserSession, UserSessionAdmin

+JSBASE = j.baseclasses.object
+

 def _command_split(cmd, author3botname="zerobot", packagename="base"):
     """
@@ -162,67 +164,6 @@ class Request:
         return self.headers.get("response_type", "auto").casefold()

-class ResponseWriter:
-    """
-    ResponseWriter is an object that expose methods
-    to write data back to the client
-    """
-
-    def __init__(self, socket):
-        self._socket = socket
-        self._writer = RedisResponseWriter(socket)
-
-    def write(self, value):
-        self._writer.encode(value)
-
-    def error(self, value):
-        if isinstance(value, dict):
-            value = j.data.serializers.json.dumps(value)
-        self._writer.error(value)
-
-
-class GedisSocket:
-    """
-    GedisSocket encapsulate the raw tcp socket
-    when you want to read the next request on the socket,
-    call the `read` method, it will return a Request object
-    when you want to write back to the client
-    call get_writer to get ReponseWriter
-    """
-
-    def __init__(self, socket):
-        self._socket = socket
-        self._parser = RedisCommandParser(socket)
-        self._writer = ResponseWriter(self._socket)
-
-    def read(self):
-        """
-        call this method when you want to process the next request
-
-        :return: return a Request
-        :rtype: tuple
-        """
-        raw_request = self._parser.read_request()
-        if not raw_request:
-            raise j.exceptions.Value("malformatted request")
-        return Request(raw_request)
-
-    @property
-    def writer(self):
-        return self._writer
-
-    def on_disconnect(self):
-        """
-        make sur to always call this method before closing the socket
-        """
-        if self._parser:
-            self._parser.on_disconnect()
-
-    @property
-    def closed(self):
-        return self._socket.closed
-
-
 class Handler(JSBASE):
     def __init__(self, gedis_server):
         JSBASE.__init__(self)
@@ -230,53 +171,62 @@ class Handler(JSBASE):
         self.cmds = {}  # caching of commands
         # will hold classes of type GedisCmds,key is the self.gedis_server._actorkey_get(
         self.cmds_meta = self.gedis_server.cmds_meta
+        self._protocol = ProtocolHandler()

     def handle_gedis(self, socket, address):
-
-        # BUG: if we start a server with kosmos --debug it should get in the debugger but it does not if errors trigger, maybe something in redis?
-        # w=self.t
-        # raise j.exceptions.Base("d")
-        gedis_socket = GedisSocket(socket)
-
+        """
+        Serve one client connection until it disconnects or fails.
+
+        Each request is decoded with the RESP ``ProtocolHandler``, dispatched
+        through ``_handle_request`` and answered with its result or with an
+        ``Error`` reply. The file object and the socket are closed on every
+        exit path.
+
+        :param socket: the accepted client socket
+        :param address: peer address, a 2-tuple formatted as "host:port" in logs
+        """
+        context = "%s:%s" % address
+        stream = socket.makefile("rwb")
         user_session = UserSessionAdmin()
-
-        try:
-            self._handle_gedis_session(gedis_socket, address, user_session=user_session)
-        except Exception as e:
-            gedis_socket.on_disconnect()
-            self._log_error("connection closed: %s" % str(e), context="%s:%s" % address, exception=e)
-
-    def _handle_gedis_session(self, gedis_socket, address, user_session=None):
-        """
-        deal with 1 specific session
-        :param socket:
-        :param address:
-        :param parser:
-        :param response:
-        :return:
-        """
-        self._log_info("new incoming connection", context="%s:%s" % address)
-
-        while True:
-            try:
-                request = gedis_socket.read()
-            except ConnectionError as err:
-                self._log_info("connection read error: %s" % str(err), context="%s:%s" % address)
-                # close the connection
-                return
+        try:
+            while True:
+                try:
+                    data = self._protocol.handle_request(stream)
+                except Disconnect:
+                    self._log_info("Client went away: %s:%s" % address)
+                    return
+                except OSError:
+                    raise  # transport failure: handled by the outer except
+                except Exception as err:
+                    # after a parse error the position in the byte stream is
+                    # unknown: report the error and drop the connection
+                    self._log_error("Malformed request %s: %s:%s" % (str(err), *address))
+                    self._protocol.write_response(stream, Error(str(err)))
+                    return

-            logdict, result = self._handle_request(request, address, user_session=user_session)
-
-            if logdict:
-                gedis_socket.writer.error(logdict)
-            try:
-                gedis_socket.writer.write(result)
-
-            except ConnectionError as err:
-                self._log_info("connection error: %s" % str(err), context="%s:%s" % address)
-                # close the connection
-                return
+                try:
+                    request = Request(data)
+                    self._log_info(f"request command {request.command}", context=context)
+                    logdict, resp = self._handle_request(request, address, user_session=user_session)
+                    if logdict:
+                        # Error replies must be text: serialize the dict the way
+                        # the former ResponseWriter.error() did
+                        resp = Error(j.data.serializers.json.dumps(logdict))
+                except Exception as err:
+                    self._log_error("Unexpected error %s: %s:%s" % (str(err), *address))
+                    # tell the client why the connection is being closed
+                    self._protocol.write_response(stream, Error(str(err)))
+                    return
+
+                self._protocol.write_response(stream, resp)
+        except OSError as err:
+            self._log_info("connection error: %s" % str(err), context=context)
+        finally:
+            try:
+                stream.close()
+            finally:
+                socket.close()

     def _authorized(self, cmd_obj, user_session):
         """
         checks if the current session is authorized to access the requested command
diff --git a/JumpscaleCore/servers/gedis/protocol_new.py b/JumpscaleCore/servers/gedis/protocol_new.py
new file mode 100644
index 00000000..ac673eec
--- /dev/null
+++ b/JumpscaleCore/servers/gedis/protocol_new.py
@@ -0,0 +1,149 @@
+from collections import namedtuple
+from io import BytesIO
+
+
+class CommandError(Exception):
+    """Raised when a request is malformed or a value cannot be encoded."""
+
+
+class Disconnect(Exception):
+    """Raised when the peer closed the connection (EOF on the stream)."""
+
+
+# Error reply; ``message`` may be str, bytes or any other object (it is str()-ed).
+Error = namedtuple("Error", ("message",))
+
+
+class ProtocolHandler(object):
+    """
+    Minimal RESP reader/writer operating on a binary file-like stream,
+    e.g. the object returned by ``socket.makefile("rwb")``.
+    """
+
+    def __init__(self):
+        # dispatch table keyed on the RESP type byte
+        self.handlers = {
+            b"+": self.handle_simple_string,
+            b"-": self.handle_error,
+            b":": self.handle_integer,
+            b"$": self.handle_string,
+            b"*": self.handle_array,
+            b"%": self.handle_dict,
+        }
+
+    def handle_request(self, stream):
+        """
+        Read and decode one RESP value from ``stream``.
+
+        :raises Disconnect: the peer closed the connection
+        :raises CommandError: the data is not valid RESP
+        """
+        first_byte = stream.read(1)
+        if not first_byte:
+            raise Disconnect()
+
+        # look the handler up on its own so that a KeyError raised while
+        # decoding a nested value is not misreported as an unknown type byte
+        handler = self.handlers.get(first_byte)
+        if handler is None:
+            raise CommandError("bad request")
+        return handler(stream)
+
+    def _read_line(self, stream):
+        """Read one CRLF-terminated line and return it without the terminator."""
+        line = stream.readline()
+        if not line:
+            # EOF before the value was complete
+            raise Disconnect()
+        return line.rstrip(b"\r\n")
+
+    def _read_int(self, stream):
+        """Read one line and parse it as a (possibly negative) integer."""
+        line = self._read_line(stream)
+        try:
+            return int(line)
+        except ValueError as err:
+            raise CommandError("bad request: invalid integer %r" % line) from err
+
+    def handle_simple_string(self, stream):
+        return self._read_line(stream)
+
+    def handle_error(self, stream):
+        return Error(self._read_line(stream))
+
+    def handle_integer(self, stream):
+        return self._read_int(stream)
+
+    def handle_string(self, stream):
+        # bulk string: $<length>\r\n<payload>\r\n
+        length = self._read_int(stream)
+        if length == -1:
+            return None  # RESP null bulk string
+        if length < 0:
+            raise CommandError("bad request: invalid bulk length %d" % length)
+        data = stream.read(length + 2)  # payload plus the trailing \r\n
+        if len(data) < length + 2:
+            # short read: the peer went away in the middle of the payload
+            raise Disconnect()
+        return data[:-2]
+
+    def handle_array(self, stream):
+        num_elements = self._read_int(stream)
+        return [self.handle_request(stream) for _ in range(num_elements)]
+
+    def handle_dict(self, stream):
+        num_items = self._read_int(stream)
+        elements = [self.handle_request(stream) for _ in range(num_items * 2)]
+        return dict(zip(elements[::2], elements[1::2]))
+
+    def write_response(self, stream, data):
+        """Encode ``data`` as RESP and send it to ``stream`` with one write and one flush."""
+        buf = BytesIO()
+        self._write(buf, data)
+        stream.write(buf.getvalue())
+        stream.flush()
+
+    def _write(self, buf, data):
+        """
+        Append the RESP encoding of ``data`` to ``buf``.
+
+        :raises CommandError: ``data`` has a type that cannot be encoded
+        """
+        if isinstance(data, str):
+            data = data.encode("utf-8")
+
+        if isinstance(data, bytes):
+            buf.write(b"$%d\r\n%s\r\n" % (len(data), data))
+        elif isinstance(data, int):
+            buf.write(b":%d\r\n" % data)
+        elif isinstance(data, Error):
+            # must stay before the tuple branch: Error is a namedtuple
+            buf.write(b"-%s\r\n" % self._error_line(data.message))
+        elif isinstance(data, (list, tuple)):
+            buf.write(b"*%d\r\n" % len(data))
+            for item in data:
+                self._write(buf, item)
+        elif isinstance(data, dict):
+            # bytes literal: BytesIO.write() rejects str
+            buf.write(b"%%%d\r\n" % len(data))
+            for key, value in data.items():
+                self._write(buf, key)
+                self._write(buf, value)
+        elif data is None:
+            buf.write(b"$-1\r\n")
+        else:
+            raise CommandError("unrecognized type: %s" % type(data))
+
+    def _error_line(self, message):
+        """Return ``message`` as bytes that fit on a single ``-`` error line."""
+        if isinstance(message, bytes):
+            raw = message
+        else:
+            raw = str(message).encode("utf-8")
+        # an error reply is one line; embedded CR/LF would desync the client
+        return raw.replace(b"\r", b" ").replace(b"\n", b" ")
+
+    def _write_buffer(self, buf, data):
+        if isinstance(data, str):
+            data = data.encode()
+        buf.write(data)
xmonader commented 4 years ago

This will be resolved in the `secrethandshake` branch.