Synss / python-mbedtls

Cryptographic library with an mbed TLS back end
MIT License

Python handle of "record from another epoch: expected 1, received 0" #88

Closed by derhex3r 1 year ago

derhex3r commented 1 year ago

I am submitting a …

Description

When a DTLS client reconnects from the same port, an error appears in the mbedtls debug log that I cannot handle at the Python level. Is there some way to intercept this error in Python code? Or somehow make the server handle the message as an epoch 0 message instead of just dropping it?

ssl_msg.c:3581: record from another epoch: expected 1, received 0
ssl_msg.c:3628: possible client reconnect from the same port
ssl_msg.c:3321: mbedtls_ssl_check_dtls_clihlo_cookie() returned -27264 (-0x6a80)
ssl_msg.c:3325: sending HelloVerifyRequest
ssl_msg.c:3332: ssl->f_send() returned 60 (-0xffffffc4)
ssl_msg.c:4496: ssl_check_client_reconnect() returned 0 (-0x0000)
ssl_msg.c:4505: discarding unexpected record (header)
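
For readers decoding the numbers in this log: each "returned" line shows the value in decimal, followed by the negated value printed as unsigned 32-bit hex. The sketch below reproduces that formatting. `format_ret`, `describe_ret`, and `KNOWN_SSL_ERRORS` are hypothetical helper names, not part of python-mbedtls, and the table holds only the one code relevant here, with its name as defined in mbedtls/ssl.h.

```python
# Partial lookup table; the symbolic name is taken from mbedtls/ssl.h.
KNOWN_SSL_ERRORS = {
    -0x6A80: "MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED",
}


def format_ret(ret: int) -> str:
    """Render a return value as "<decimal> (-0x<hex>)", like the log lines."""
    # The hex part is the negated value viewed as an unsigned 32-bit integer,
    # which is why a positive byte count such as 60 appears as -0xffffffc4.
    return f"{ret} (-0x{(-ret) & 0xFFFFFFFF:04x})"


def describe_ret(ret: int) -> str:
    """Append the symbolic name when the code is in the partial table."""
    name = KNOWN_SSL_ERRORS.get(ret, "not listed in this partial table")
    return f"{format_ret(ret)}: {name}"


# The three values that appear in the log above.
for value in (-27264, 60, 0):
    print(describe_ret(value))
```

Read this way, the cookie check did not fail unexpectedly: it asked for hello verification, and the following `f_send()` line reports 60 bytes sent rather than an error.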

Current behavior

We can see this error message in the mbedtls debug log, but no error is raised in the Python code.

Expected behavior

Raise a Python error that can be handled correctly. In the best case, make the mbedtls server handle the same ClientHello message again as an epoch 0 message, without the client having to resend it.

Steps to reproduce

  1. Run the attached server-side code
  2. Run the attached client-side code and stop it with ctrl+C
  3. Run the attached client-side code again.
  4. "record from another epoch: expected 1, received 0" will appear in the debug log.

Minimal demo of the problem

Server-side code:

from __future__ import annotations

import time
import threading
import socket
from contextlib import suppress

import datetime
import logging
import ipaddress
import select

from mbedtls._tls import _enable_debug_output, _set_debug_level  # type: ignore
from mbedtls.exceptions import TLSError
from mbedtls.tls import (
    ServerContext,
    HelloVerifyRequest,
    DTLSConfiguration,
)

def run_dtls_server( srv_conf,
        address: str,
        port: int):

    try:
        ip = ipaddress.ip_address(address)
        if ip.version == 4:
            _sock = ServerContext(srv_conf).wrap_socket(
                socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
            )

        elif ip.version == 6:
            _sock = ServerContext(srv_conf).wrap_socket(
                socket.socket(socket.AF_INET6, socket.SOCK_DGRAM),
            )
    except ValueError:
        print(f"The address '{address}' is not a valid IP address.")
        raise  # _sock is undefined past this point, so do not continue

    _enable_debug_output(_sock.context)
    _set_debug_level(2)

    _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    _sock.bind((address, port))

    sockets = [_sock]

    while True:
        ready_to_read, _, _ = select.select(sockets, [], [])

        for sock in ready_to_read:
            if sock is _sock:
                assert _sock
                conn, addr = _sock.accept()

                print("Connection from: ", addr)
                # random_cookie = generate_random_ascii(10)

                conn.setcookieparam(addr[0].encode("ascii"))
                # conn.setcookieparam(random_cookie)

                try:
                    with suppress(HelloVerifyRequest):
                        conn.do_handshake()
                except Exception as e:
                    print("Handshake 1 Exception: ", e)
                    print("Error type = ", type(e))

                _, (conn, addr) = conn, conn.accept()
                _.close()

                conn.setcookieparam(addr[0].encode("ascii"))
                # conn.setcookieparam(random_cookie)

                try:
                    conn.do_handshake()
                except Exception as e:
                    print("Handshake 2 Exception: ", e)
                    print("Error type = ", type(e))
                    continue

                sockets.append(conn)

            else:
                data, client_address = sock.recvfrom(4096)
                print("From:", client_address, "Received data:", data)
                msg = "srv_"+data.decode("utf-8")+" "+str(datetime.datetime.now())
                sock.send(msg.encode("utf-8"))

srv_address = ("127.0.0.1", 9000)

conf = DTLSConfiguration(
    pre_shared_key_store={'Client_identity': b'SECRET'},
    validate_certificates=False,
    handshake_timeout_min=99999,
    handshake_timeout_max=99999,
)

stop_event = threading.Event()

while True:
    is_server_alive = run_dtls_server(conf, address=srv_address[0], port=srv_address[1])
    if is_server_alive is False:
        print("Server has crashed, restarting...")

Client-side code:

from __future__ import annotations

import ipaddress
import socket
import time

from mbedtls._tls import _enable_debug_output, _set_debug_level  # type: ignore
from mbedtls.exceptions import TLSError
from mbedtls.tls import (
    ClientContext,
    DTLSConfiguration,
)

def run_dtls_client(
    cli_conf,
    srv_address: str,
    srv_port: int,
    srv_hostname,
    cli_address = "0.0.0.0",
    cli_port = 4000,
    ): 
    try:
        ip = ipaddress.ip_address(cli_address)
        if ip.version == 4:
            _sock = ClientContext(cli_conf).wrap_socket(
                socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                server_hostname=srv_hostname,
            )
        elif ip.version == 6:
            _sock = ClientContext(cli_conf).wrap_socket(
                socket.socket(socket.AF_INET6, socket.SOCK_DGRAM),
                server_hostname=srv_hostname,
            )

    except ValueError:
        print(f"The address '{cli_address}' is not a valid IP address.")
        raise  # _sock is undefined past this point, so do not continue

    _sock.bind((cli_address, cli_port))
    _sock.connect((srv_address, srv_port))

    _enable_debug_output(_sock.context)
    _set_debug_level(2)    

    _sock.do_handshake()

    while True:
        _sock.send(b"client1_data")
        incoming_data = _sock.recv(4096)
        print(f"Client received: {incoming_data}")
        time.sleep(1)

dtls_cli_conf = DTLSConfiguration(
    pre_shared_key=('Client_identity', b'SECRET'),
    validate_certificates=False,
    handshake_timeout_min=99999,
    handshake_timeout_max=99999,
    ciphers=["TLS-PSK-WITH-AES-256-CCM"],
)

srv_address = ("127.0.0.1", 9000)

run_dtls_client(
    cli_conf=dtls_cli_conf,
    cli_address="0.0.0.0",
    cli_port=4000,
    srv_address=srv_address[0],
    srv_port=srv_address[1],
    srv_hostname="localhost",
)

Synss commented 1 year ago

Thank you for your feedback.

The "record from another epoch" message indeed comes from https://github.com/Mbed-TLS/mbedtls/blob/development/library/ssl_msg.c#L3678-L3684 and the corresponding comment mentions that this should be handled in the caller.

I'll have to look into this. Also note that I'd gratefully accept a PR as well.

Synss commented 1 year ago

@derhex3r: Thank you for your contribution again! I've been thinking about this. As the code I've quoted above shows, there is no way to raise this error to Python, as it doesn't even leave the C function where it is generated. I wouldn't know how to handle it in my TLSWrappedSocket either, as the socket already does what it's supposed to do, transparently.

I think the handling for this use case has to occur in user code, that is, the server you write using these libraries.

In the present case, the server should just close and reset the TLSWrappedSocket or TLSWrappedBuffer.
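
A minimal sketch of what that could look like in user code follows. Since the discarded record never surfaces as a Python exception, the sketch assumes the server decides on its own when a connection is dead, here with an idle timeout. `PeerTable` and its methods are hypothetical names, not part of python-mbedtls; the only thing the sketch requires of a connection object is a `close()` method.

```python
import time
from typing import Any, Callable, Dict, Hashable, List, Tuple


class PeerTable:
    """Track one connection per peer address and close the stale ones."""

    def __init__(
        self,
        idle_timeout: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._idle_timeout = idle_timeout
        self._clock = clock
        self._peers: Dict[Hashable, Tuple[Any, float]] = {}

    def add(self, addr: Hashable, conn: Any) -> None:
        """Register a connection, closing any earlier one for the same peer."""
        old = self._peers.pop(addr, None)
        if old is not None:
            old[0].close()
        self._peers[addr] = (conn, self._clock())

    def touch(self, addr: Hashable) -> None:
        """Record that application data just arrived from this peer."""
        conn, _ = self._peers[addr]
        self._peers[addr] = (conn, self._clock())

    def reap_idle(self) -> List[Hashable]:
        """Close and forget every peer that has been silent for too long."""
        now = self._clock()
        stale = [
            addr
            for addr, (_, seen) in self._peers.items()
            if now - seen > self._idle_timeout
        ]
        for addr in stale:
            conn, _ = self._peers.pop(addr)
            conn.close()
        return stale
```

In a loop like the one in the demo server, this would mean calling `add()` after a successful handshake, `touch()` after each successful `recvfrom()`, and `reap_idle()` periodically (for example by giving `select.select` a timeout), then dropping the reaped connections from the `sockets` list. Whether an idle timeout is acceptable depends on the application's traffic pattern.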

derhex3r commented 1 year ago

Thanks! Do I understand correctly that if I just close the TLSWrappedSocket, I will lose the ClientHello message that was already received? Do you have any ideas about how I can establish a new TLSWrappedSocket with this ClientHello message if I close the old one?

Synss commented 1 year ago

Unless I misunderstood, the best option is to reset the connection and restart the handshake.
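
As a rough illustration of "reset and restart", the helper below retries a handshake on a fresh connection each time. It is a sketch under assumptions, not library code: `handshake_with_restart`, `make_connection`, and `retry_on` are hypothetical names, and the connection object only needs the `do_handshake()` and `close()` methods already used in the demo code above.

```python
from typing import Any, Callable, Optional, Tuple, Type


def handshake_with_restart(
    make_connection: Callable[[], Any],
    retry_on: Tuple[Type[BaseException], ...],
    max_attempts: int = 3,
) -> Any:
    """Return a connection whose handshake succeeded, restarting on failure.

    make_connection: hypothetical factory that builds a new, unconnected
        wrapped socket for every attempt.
    retry_on: exception types that should trigger a restart.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    last_error: Optional[BaseException] = None
    for _ in range(max_attempts):
        conn = make_connection()
        try:
            conn.do_handshake()
        except retry_on as exc:
            # Drop the half-open connection so the next attempt starts clean.
            last_error = exc
            conn.close()
            continue
        return conn
    raise ConnectionError(
        f"handshake failed after {max_attempts} attempts"
    ) from last_error
```

With python-mbedtls, `retry_on` could plausibly be `(TLSError,)` from `mbedtls.exceptions`, which the demo already imports, but which exceptions are worth retrying is a decision for the application.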