Closed — xmonader closed this issue 4 years ago.
For reference, I'm posting previous work by @zaibon and @grimpy, along with the spots they found:
diff --git a/JumpscaleCore/servers/gedis/handlers.py b/JumpscaleCore/servers/gedis/handlers.py index 6a88d11e..2c44597a 100644 --- a/JumpscaleCore/servers/gedis/handlers.py +++ b/JumpscaleCore/servers/gedis/handlers.py @@ -1,14 +1,16 @@ -from Jumpscale import j -from redis.exceptions import ConnectionError +import binascii + import nacl -from .protocol import RedisCommandParser, RedisResponseWriter from nacl.signing import VerifyKey -import binascii +from redis.exceptions import ConnectionError -JSBASE = j.baseclasses.object +from Jumpscale import j +from .protocol_new import Disconnect, Error, ProtocolHandler from .UserSession import UserSession, UserSessionAdmin +JSBASE = j.baseclasses.object + def _command_split(cmd, author3botname="zerobot", packagename="base"): """ @@ -162,67 +164,6 @@ class Request: return self.headers.get("response_type", "auto").casefold() -class ResponseWriter: - """ - ResponseWriter is an object that expose methods - to write data back to the client - """ - - def __init__(self, socket): - self._socket = socket - self._writer = RedisResponseWriter(socket) - - def write(self, value): - self._writer.encode(value) - - def error(self, value): - if isinstance(value, dict): - value = j.data.serializers.json.dumps(value) - self._writer.error(value) - - -class GedisSocket: - """ - GedisSocket encapsulate the raw tcp socket - when you want to read the next request on the socket, - call the `read` method, it will return a Request object - when you want to write back to the client - call get_writer to get ReponseWriter - """ - - def __init__(self, socket): - self._socket = socket - self._parser = RedisCommandParser(socket) - self._writer = ResponseWriter(self._socket) - - def read(self): - """ - call this method when you want to process the next request - - :return: return a Request - :rtype: tuple - """ - raw_request = self._parser.read_request() - if not raw_request: - raise j.exceptions.Value("malformatted request") - return Request(raw_request) - 
- @property - def writer(self): - return self._writer - - def on_disconnect(self): - """ - make sur to always call this method before closing the socket - """ - if self._parser: - self._parser.on_disconnect() - - @property - def closed(self): - return self._socket.closed - - class Handler(JSBASE): def __init__(self, gedis_server): JSBASE.__init__(self) @@ -230,53 +171,36 @@ class Handler(JSBASE): self.cmds = {} # caching of commands # will hold classes of type GedisCmds,key is the self.gedis_server._actorkey_get( self.cmds_meta = self.gedis_server.cmds_meta + self._protocol = ProtocolHandler() def handle_gedis(self, socket, address): - # BUG: if we start a server with kosmos --debug it should get in the debugger but it does not if errors trigger, maybe something in redis? - # w=self.t - # raise j.exceptions.Base("d") - gedis_socket = GedisSocket(socket) - + stream = socket.makefile("rwb") user_session = UserSessionAdmin() - - try: - self._handle_gedis_session(gedis_socket, address, user_session=user_session) - except Exception as e: - gedis_socket.on_disconnect() - self._log_error("connection closed: %s" % str(e), context="%s:%s" % address, exception=e) - - def _handle_gedis_session(self, gedis_socket, address, user_session=None): - """ - deal with 1 specific session - :param socket: - :param address: - :param parser: - :param response: - :return: - """ - self._log_info("new incoming connection", context="%s:%s" % address) - while True: try: - request = gedis_socket.read() - except ConnectionError as err: - self._log_info("connection read error: %s" % str(err), context="%s:%s" % address) - # close the connection + data = self._protocol.handle_request(stream) + except Disconnect: + self._log_info("Client went away: %s:%s" % address) + stream.close() + socket.close() return - logdict, result = self._handle_request(request, address, user_session=user_session) - - if logdict: - gedis_socket.writer.error(logdict) try: - gedis_socket.writer.write(result) - - except 
ConnectionError as err: - self._log_info("connection error: %s" % str(err), context="%s:%s" % address) - # close the connection + request = Request(data) + self._log_info(f"request command {request.command}", context="%s:%s" % address) + logdict, resp = self._handle_request(request, address, user_session=user_session) + if logdict: + resp = Error(logdict) + except Exception as err: + self._log_error("Unexpected error %s: %s:%s" % (str(err), *address)) + resp = Error(err.args[0]) + stream.close() + socket.close() return + self._protocol.write_response(stream, resp) + def _authorized(self, cmd_obj, user_session): """ checks if the current session is authorized to access the requested command diff --git a/JumpscaleCore/servers/gedis/protocol_new.py b/JumpscaleCore/servers/gedis/protocol_new.py new file mode 100644 index 00000000..ac673eec --- /dev/null +++ b/JumpscaleCore/servers/gedis/protocol_new.py @@ -0,0 +1,98 @@ +from collections import namedtuple +from io import BytesIO + + +class CommandError(Exception): + pass + + +class Disconnect(Exception): + pass + + +Error = namedtuple("Error", ("message",)) + + +class ProtocolHandler(object): + def __init__(self): + self.handlers = { + b"+": self.handle_simple_string, + b"-": self.handle_error, + b":": self.handle_integer, + b"$": self.handle_string, + b"*": self.handle_array, + b"%": self.handle_dict, + } + + def handle_request(self, stream): + first_byte = stream.read(1) + if not first_byte: + raise Disconnect() + + try: + # Delegate to the appropriate handler based on the first byte. + return self.handlers[first_byte](stream) + except KeyError: + raise CommandError("bad request") + + def handle_simple_string(self, stream): + return stream.readline().rstrip(b"\r\n") + + def handle_error(self, stream): + return Error(stream.readline().rstrip(b"\r\n")) + + def handle_integer(self, stream): + return int(stream.readline().rstrip(b"\r\n")) + + def handle_string(self, stream): + # First read the length ($<length>\r\n). 
+ length = int(stream.readline().rstrip(b"\r\n")) + if length == -1: + return None # Special-case for NULLs. + length += 2 # Include the trailing \r\n in count. + return stream.read(length)[:-2] + + def handle_array(self, stream): + num_elements = int(stream.readline().rstrip(b"\r\n")) + return [self.handle_request(stream) for _ in range(num_elements)] + + def handle_dict(self, stream): + num_items = int(stream.readline().rstrip(b"\r\n")) + elements = [self.handle_request(stream) for _ in range(num_items * 2)] + return dict(zip(elements[::2], elements[1::2])) + + def write_response(self, stream, data): + buf = BytesIO() + self._write(buf, data) + buf.seek(0) + stream.write(buf.getvalue()) + stream.flush() + + def _write(self, buf, data): + if isinstance(data, str): + data = data.encode("utf-8") + + if isinstance(data, bytes): + buf.write(b"$%d\r\n%s\r\n" % (len(data), data)) + elif isinstance(data, int): + buf.write(b":%d\r\n" % data) + elif isinstance(data, Error): + buf.write(b"-%s\r\n" % data.message.encode("utf-8")) + elif isinstance(data, (list, tuple)): + buf.write(b"*%d\r\n" % len(data)) + for item in data: + self._write(buf, item) + elif isinstance(data, dict): + buf.write("%%%d\r\n" % len(data)) + for key in data: + self._write(buf, key) + self._write(buf, data[key]) + elif data is None: + buf.write(b"$-1\r\n") + else: + raise CommandError("unrecognized type: %s" % type(data)) + + def _write_buffer(self, buf, data): + if isinstance(data, str): + data = data.encode() + buf.write(data)
This will be resolved in the secrethandshake branch.
For reference, I'm posting previous work by @zaibon and @grimpy, along with the spots they found: