mfavant / tubekit

NEW PROJECT https://github.com/crust-hub/avant
MIT License
0 stars 0 forks source link

feat:support OpenSSL #18

Closed gaowanlu closed 11 months ago

gaowanlu commented 11 months ago

socket_handler demo

#include <unistd.h>
#include <tubekit-log/logger.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

#include "socket/socket_handler.h"
#include "socket/server_socket.h"
#include "thread/auto_lock.h"
#include "utility/singleton.h"
#include "thread/worker_pool.h"
#include "task/task_factory.h"
#include "server/server.h"
#include "hooks/tick.h"
#include "hooks/stop.h"
#include "system/system.h"
#include "connection/connection.h"
#include "connection/connection_mgr.h"
#include "connection/http_connection.h"
#include "connection/stream_connection.h"

using namespace std;
using namespace tubekit::socket;
using namespace tubekit::thread;
using namespace tubekit::task;
using namespace tubekit::log;
using namespace tubekit::utility;
using namespace tubekit::server;
using namespace tubekit::hooks;
using namespace tubekit::connection;

#define TUBEKIT_OPENSSL

socket_handler::socket_handler() : m_init(false)
{
}

/**
 * Release the epoll wrapper and the listening socket owned by this handler.
 *
 * Each pointer is reset after deletion so that it never dangles.
 */
socket_handler::~socket_handler()
{
    if (m_epoll != nullptr)
    {
        delete m_epoll;
        m_epoll = nullptr;
    }
    if (m_server != nullptr)
    {
        delete m_server;
        m_server = nullptr; // was resetting m_epoll again, leaving m_server dangling
    }
}

/**
 * Register (or re-arm) a socket in epoll as a one-shot watcher.
 *
 * @param m_socket    socket whose descriptor should be watched
 * @param listen_send when true, writability (EPOLLOUT) is watched in addition
 *                    to readability
 * @return -1 if the handler is not initialised; 0 if the socket is already
 *         waiting for EPOLLOUT or was added successfully; otherwise the result
 *         of the failed add, or of the fallback modify when the descriptor was
 *         already registered (EEXIST).
 */
int socket_handler::attach(socket *m_socket, bool listen_send /*= false*/)
{
    if (!m_init)
    {
        return -1;
    }

    auto_lock lock(m_mutex);

    // A pending EPOLLOUT registration guarantees another loop iteration for
    // this socket, so there is nothing to change.
    const uint32_t registered = m_epoll->get_events_by_fd(m_socket->m_sockfd);
    if (registered & EPOLLOUT)
    {
        return 0;
    }

    uint32_t wanted = EPOLLONESHOT | EPOLLIN | EPOLLHUP | EPOLLERR;
    if (listen_send)
    {
        wanted |= EPOLLOUT;
    }

    // Try to add first; fall back to modify only when the fd is already known.
    const int add_result = m_epoll->add(m_socket->m_sockfd, (void *)m_socket, wanted);
    if (add_result == 0)
    {
        return 0;
    }
    if (add_result == -1 && errno == EEXIST)
    {
        return m_epoll->mod(m_socket->m_sockfd, (void *)m_socket, wanted);
    }
    return add_result;
}

/**
 * Remove a socket's descriptor from epoll.
 *
 * @param m_socket socket to stop watching
 * @return -1 if the handler is not initialised, otherwise whatever the
 *         poller's del() returns.
 */
int socket_handler::detach(socket *m_socket)
{
    if (m_init)
    {
        auto_lock lock(m_mutex);
        return m_epoll->del(m_socket->m_sockfd, (void *)m_socket, 0);
    }
    return -1;
}

/**
 * Stop watching a socket, close it and hand it back to the socket pool.
 *
 * @param m_socket socket to retire
 * @return -1 if the handler is not initialised, otherwise the result of
 *         detach(); a failed detach does not prevent the close/release.
 */
int socket_handler::remove(socket *m_socket)
{
    if (!m_init)
    {
        return -1;
    }

    const int detach_result = detach(m_socket);
    // A non-zero detach result is intentionally not treated as fatal:
    // LOG_ERROR("detach(m_socket) return %d", detach_result);

    m_socket->close();
    // TODO: take care of SSL_shutdown / SSL_free handling here
    socket_pool.release(m_socket); // give the object back to the socket pool
    return detach_result;
}

/**
 * Take a socket object from the pool.
 *
 * @return nullptr if the handler is not initialised, otherwise whatever the
 *         pool's allocate() returns.
 */
socket *socket_handler::alloc_socket()
{
    return m_init ? socket_pool.allocate() : nullptr;
}

/**
 * Run the registered tick hooks; called once per event-loop iteration.
 */
void socket_handler::on_tick()
{
    auto &&tick_hooks = singleton<hooks::tick>::instance();
    tick_hooks->run();
}

/**
 * Create the listening socket, the epoll instance and the socket pool, and
 * (with TUBEKIT_OPENSSL) the shared SSL context used for every connection.
 *
 * @param ip              address passed to the listening server_socket
 * @param port            port passed to the listening server_socket
 * @param max_connections size given to the epoll instance and the socket pool
 * @param wait_time       timeout later passed to the poller's wait() in handle()
 * @return true when the handler is initialised (also when it already was);
 *         false when the SSL context, certificate or private key could not be
 *         set up. On such a failure the SSL context is released again.
 */
bool socket_handler::init(const string &ip, int port, int max_connections, int wait_time)
{
    if (m_init)
    {
        LOG_ERROR("socket handler already init");
        return true;
    }
    m_server = new server_socket(ip, port);
    m_max_connections = max_connections;
    m_wait_time = wait_time;
    m_epoll = new event_poller(false); // false:EPOLLLT mode
    m_epoll->create(max_connections);
    m_epoll->add(m_server->m_sockfd, m_server, (EPOLLIN | EPOLLHUP | EPOLLERR)); // Register the listen socket epoll_event
    socket_pool.init(max_connections);

#ifdef TUBEKIT_OPENSSL
    // OpenSSL
    SSL_library_init();
    SSL_load_error_strings();
    // Store the context in the member that handle() reads; the previous code
    // declared a local of the same name and then used an undeclared identifier.
    m_ssl_context = SSL_CTX_new(SSLv23_server_method());
    if (!m_ssl_context)
    {
        LOG_ERROR("SSL_CTX_new error: %s", ERR_error_string(ERR_get_error(), nullptr));
        return false;
    }
    int i_ret = SSL_CTX_use_certificate_file(m_ssl_context, "server.crt", SSL_FILETYPE_PEM);
    if (1 != i_ret)
    {
        LOG_ERROR("SSL_CTX_use_certificate_file error: %s", ERR_error_string(ERR_get_error(), nullptr));
        SSL_CTX_free(m_ssl_context); // do not leak the context on failure
        m_ssl_context = nullptr;
        return false;
    }

    i_ret = SSL_CTX_use_PrivateKey_file(m_ssl_context, "server.key", SSL_FILETYPE_PEM);
    if (1 != i_ret)
    {
        LOG_ERROR("SSL_CTX_use_PrivateKey_file error: %s", ERR_error_string(ERR_get_error(), nullptr));
        SSL_CTX_free(m_ssl_context); // do not leak the context on failure
        m_ssl_context = nullptr;
        return false;
    }
#endif

    m_init = true;
    return m_init;
}

/**
 * Main-thread event loop.
 *
 * Every iteration checks the server stop flag, waits on epoll for up to
 * m_wait_time, runs the tick hooks and then processes each ready event:
 *  - an event on the listening socket accepts one connection, wraps it in a
 *    pooled socket object (plus an SSL object when TUBEKIT_OPENSSL is defined),
 *    creates the connection-layer instance for the configured task type and
 *    registers the socket in epoll;
 *  - an event on any other socket detaches it from epoll and either advances
 *    the TLS handshake (TUBEKIT_OPENSSL) or creates a task and assigns it to
 *    the worker pool.
 *
 * Returns when the server reports is_stop() or when the poller's wait() fails
 * with an errno other than EINTR. Only logs and returns if init() has not
 * succeeded. Terminates the process if a task cannot be created.
 */
void socket_handler::handle()
{
    if (!m_init)
    {
        LOG_ERROR("socket_handler not init,can not execute handle");
        return;
    }
    // main thread loop
    while (true)
    {
        // sys stop check
        if (singleton<tubekit::server::server>::instance()->is_stop())
        {
            singleton<tubekit::server::server>::instance()->on_stop();
            singleton<hooks::stop>::instance()->run();
#ifdef TUBEKIT_OPENSSL
            // release the shared SSL context and clear the pointer so it cannot be reused
            SSL_CTX_free(m_ssl_context);
            m_ssl_context = nullptr;
#endif
            break; // main process to exit
        }
        int num = m_epoll->wait(m_wait_time);
        on_tick();
        if (num == 0)
        {
            continue; // timeout
        }
        else if (num < 0)
        {
            if (errno == EINTR)
            {
                continue;
            }
            break;
        }

        for (int i = 0; i < num; i++) // walk every ready event
        {
            // There is a new socket connection
            if (m_server == static_cast<socket *>(m_epoll->m_events[i].data.ptr))
            {
                int socket_fd = m_server->accept(); // Gets the socket_fd for the new connection
                if (socket_fd < 0)
                {
                    // accept failed: there is no descriptor to wrap
                    continue;
                }
                socket *socket_object = alloc_socket();
                if (socket_object == nullptr)
                {
                    // no pooled socket object available: close the accepted
                    // descriptor instead of leaking it
                    ::close(socket_fd);
                    continue;
                }
                socket_object->m_sockfd = socket_fd;
                socket_object->close_callback = nullptr;
                socket_object->set_non_blocking();
                socket_object->set_linger(false, 0);

#ifdef TUBEKIT_OPENSSL
                bool ssl_err = false;
                SSL *ssl_instance = SSL_new(m_ssl_context);
                if (!ssl_instance)
                {
                    ssl_err = true;
                    LOG_ERROR("SSL_new return NULL");
                }
                if (!ssl_err && 1 != SSL_set_fd(ssl_instance, socket_object->m_sockfd))
                {
                    ssl_err = true;
                    LOG_ERROR("SSL_set_fd error: %s", ERR_error_string(ERR_get_error(), nullptr));
                }
                if (ssl_err)
                {
                    LOG_ERROR("SSL ERR");
                    // not yet bound to socket_object, so free it here;
                    // SSL_free on a null pointer does nothing
                    SSL_free(ssl_instance);
                    remove(socket_object);
                    continue;
                }
                // Bind ssl_instance to socket_object so that closing the socket
                // can also take care of the SSL object.
                socket_object->ssl_instance = ssl_instance;
#endif

                // create connection layer instance
                auto task_type = singleton<server::server>::instance()->get_task_type();
                connection::connection *p_connection = nullptr;
                {
                    switch (task_type)
                    {
                    case server::server::STREAM_TASK:
                        p_connection = new (std::nothrow) connection::stream_connection(socket_object);
                        if (p_connection == nullptr)
                        {
                            LOG_ERROR("new connection::stream_connection error");
                        }
                        break;
                    case server::server::HTTP_TASK:
                        p_connection = new (std::nothrow) connection::http_connection(socket_object);
                        if (p_connection == nullptr)
                        {
                            LOG_ERROR("new connection::http_connection error");
                        }
                        break;
                    case server::server::WEBSOCKET_TASK:
                        p_connection = new (std::nothrow) connection::websocket_connection(socket_object);
                        if (p_connection == nullptr)
                        {
                            LOG_ERROR("new connection::websocket_connection error");
                        }
                        break;
                    default:
                        break;
                    }
                }

                if (p_connection == nullptr)
                {
                    remove(socket_object);
                    continue;
                }

                bool bret = singleton<connection_mgr>::instance()->add(socket_object, p_connection);
                if (!bret)
                {
                    LOG_ERROR("singleton<connection_mgr>::instance()->add error");
                    delete p_connection;
                    remove(socket_object);
                    continue;
                }
                // TODO: run one task right away for a freshly accepted connection
                attach(socket_object, true); // watch for both read and write
            }                                // There is a new socket connection
            else                             // already connection socket has event happen
            {
                // already connection socket process
                uint32_t events = m_epoll->m_events[i].events;
                socket *socket_ptr = static_cast<socket *>(m_epoll->m_events[i].data.ptr);
                detach(socket_ptr);
                // get connection layer instance
                connection::connection *p_connection = singleton<connection_mgr>::instance()->get(socket_ptr);
                if (p_connection == nullptr)
                {
                    LOG_ERROR("exsit socket,but not exist connection");
                    remove(socket_ptr);
                    continue;
                }

                if ((events & EPOLLHUP) || (events & EPOLLERR))
                {
                    // using connection_mgr mark_close,to prevent connection already free
                    singleton<connection_mgr>::instance()->mark_close(socket_ptr);
                }

                // Record which readiness events triggered this wake-up
                bool recv_event = (events & EPOLLIN) != 0;
                bool send_event = (events & EPOLLOUT) != 0;

// With SSL enabled the TLS handshake must finish before any task is created.
#ifdef TUBEKIT_OPENSSL
                SSL *ssl_instance = socket_ptr->ssl_instance;
                if (!SSL_is_init_finished(ssl_instance))
                {
                    int ssl_status = SSL_accept(ssl_instance);
                    if (1 != ssl_status)
                    {
                        int ssl_error = SSL_get_error(ssl_instance, ssl_status);
                        if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE)
                        {
                            // The handshake needs more input or more output
                            // space: re-arm, also for writability when that is
                            // what OpenSSL is waiting for.
                            attach(socket_ptr, ssl_error == SSL_ERROR_WANT_WRITE);
                            continue;
                        }
                        LOG_ERROR("SSL_accept error: %d", ssl_error);
                        // NOTE(review): the connection_mgr entry for this socket
                        // is not released on this path — confirm how it should be.
                        remove(socket_ptr);
                        continue;
                    }
                    // Handshake complete. The socket was detached above, so it
                    // has to be re-armed or it would never be reported again.
                    attach(socket_ptr);
                    continue;
                }
#endif

                // Decide which engine to use,such as WORKDLOW_TASK or HTTP_TASK
                auto task_type = singleton<server::server>::instance()->get_task_type();
                thread::task *new_task = nullptr;

                // create task
                switch (task_type)
                {
                case server::server::STREAM_TASK:
                    new_task = task_factory::create(socket_ptr, task_factory::STREAM_TASK);
                    if (new_task)
                    {
                        auto stream_task_ptr = (stream_task *)new_task;
                        stream_task_ptr->reason_recv = recv_event;
                        stream_task_ptr->reason_send = send_event;
                    }
                    break;
                case server::server::HTTP_TASK:
                    new_task = task_factory::create(socket_ptr, task_factory::HTTP_TASK);
                    if (new_task)
                    {
                        auto http_task_ptr = (http_task *)new_task;
                        http_task_ptr->reason_recv = recv_event;
                        http_task_ptr->reason_send = send_event;
                    }
                    break;
                case server::server::WEBSOCKET_TASK:
                    new_task = task_factory::create(socket_ptr, task_factory::WEBSOCKET_TASK);
                    break;
                default:
                    break;
                }

                // create task failed
                if (new_task == nullptr)
                {
                    LOG_ERROR("new_task is nullptr");
                    exit(EXIT_FAILURE);
                }

                // Submit the task to the queue of task_dispatcher
                singleton<worker_pool>::instance()->assign(new_task);
            }
        }
    }
}
gaowanlu commented 11 months ago

socket demo

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <fcntl.h>
#include <cerrno>
#include <cstring>
#include <tubekit-log/logger.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

#include "socket/socket.h"
#include "utility/singleton.h"

using namespace tubekit::socket;
using namespace tubekit::utility;
using namespace tubekit::log;

socket::socket() : m_sockfd(0)
{
}

/**
 * Construct a socket object that remembers an address but owns no descriptor
 * yet (m_sockfd == 0).
 *
 * @param ip   address stored in m_ip
 * @param port port stored in m_port
 */
socket::socket(const string &ip, int port)
    : m_ip(ip),
      m_port(port),
      m_sockfd(0)
{
}

/**
 * Tear down the TLS session (if any) and close the descriptor.
 *
 * The SSL object is only touched when one is attached: SSL_shutdown must not
 * be called with a null pointer.
 */
socket::~socket()
{
    if (m_ssl_instance != nullptr)
    {
        // TODO: SSL_shutdown on non-blocking IO needs special handling; only a
        // return value of 1 really means success, but since the connection is
        // being dropped anyway the result is ignored.
        SSL_shutdown(m_ssl_instance);
        SSL_free(m_ssl_instance);
        m_ssl_instance = nullptr;
    }
    close();
}

bool socket::bind(const string &ip, int port)
{
    struct sockaddr_in sockaddr;
    memset(&sockaddr, 0, sizeof(sockaddr)); // init 0
    sockaddr.sin_family = AF_INET;          // IPV4
    if (ip != "")
    {
        sockaddr.sin_addr.s_addr = inet_addr(ip.c_str());
    }
    else
    {
        sockaddr.sin_addr.s_addr = htonl(INADDR_ANY); // 0.0.0.0
    }
    // htonl htons : change to net byte sequeue from host byte
    sockaddr.sin_port = htons(port);
    if (::bind(m_sockfd, (struct sockaddr *)&sockaddr, sizeof(sockaddr)) < 0)
    {
        LOG_ERROR("socket bind error: errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }
    return true;
}

/**
 * Put m_sockfd into listening state.
 *
 * @param backlog length of the queue of pending connections
 * @return true on success, false (after logging errno) on failure.
 */
bool socket::listen(int backlog)
{
    const bool listening = ::listen(m_sockfd, backlog) >= 0;
    if (!listening)
    {
        LOG_ERROR("socket listen error: errno=%d errstr=%s", errno, strerror(errno));
    }
    return listening;
}

bool socket::connect(const string &ip, int port)
{
    struct sockaddr_in sockaddr;
    memset(&sockaddr, 0, sizeof(sockaddr));
    sockaddr.sin_family = AF_INET;
    sockaddr.sin_addr.s_addr = inet_addr(ip.c_str());
    sockaddr.sin_port = htons(port);
    if (::connect(m_sockfd, (struct sockaddr *)&sockaddr, sizeof(sockaddr)) < 0)
    {
        LOG_ERROR("socket connect error: errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }
    return true;
}

/**
 * Close the descriptor and reset the object so it can be reused.
 *
 * Invokes close_callback first (when one is set), closes m_sockfd if it is
 * greater than zero, then clears the callback, address and port.
 *
 * @return always true.
 */
bool socket::close()
{
    // Give the owner a chance to react before the descriptor goes away.
    if (close_callback)
        close_callback();

    if (m_sockfd > 0)
    {
        ::close(m_sockfd);
        m_sockfd = 0; // 0 marks "no descriptor"
    }

    // Reset the remaining per-connection state.
    close_callback = nullptr;
    m_ip.clear();
    m_port = 0;

    return true;
}

int socket::accept()
{
    int sockfd = ::accept(m_sockfd, NULL, NULL);
    if (sockfd < 0)
    {
        LOG_ERROR("accept call error: errno=%d errstr=%s", errno, strerror(errno));
        sockfd = -1;
    }
    return sockfd;
}

/**
 * Read up to len bytes into buf, through plain ::recv or (with
 * TUBEKIT_OPENSSL) through SSL_read on m_ssl_instance.
 *
 * @param buf        destination buffer
 * @param len        capacity of buf in bytes
 * @param oper_errno receives errno (plain path), EAGAIN when OpenSSL wants the
 *                   call retried, or the SSL error code otherwise.
 *                   NOTE(review): this definition takes oper_errno by value, so
 *                   the caller never sees what is stored — confirm whether the
 *                   declaration is meant to be `int &`.
 * @return the number of bytes read, 0, or a negative value on failure, exactly
 *         as returned by the underlying call.
 */
int socket::recv(char *buf, size_t len, int oper_errno)
{
#ifndef TUBEKIT_OPENSSL
    // read len bytes from m_sockfd
    int result = ::recv(m_sockfd, buf, len, 0);
    if (result == -1)
    {
        oper_errno = errno;
    }
    return result;
#else
    int bytes_received = SSL_read(m_ssl_instance, buf, static_cast<int>(len));
    if (bytes_received >= 0)
    {
        return bytes_received;
    }

    int ssl_error = SSL_get_error(m_ssl_instance, bytes_received);
    if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE)
    {
        // Not a failure: OpenSSL needs the socket to become ready again.
        oper_errno = EAGAIN;
    }
    else
    {
        oper_errno = ssl_error;
        LOG_ERROR("SSL_read error: %d", ssl_error);
    }
    return bytes_received;
#endif
}

/**
 * Write up to len bytes from buf, through plain ::send or (with
 * TUBEKIT_OPENSSL) through SSL_write on m_ssl_instance.
 *
 * @param buf        data to send
 * @param len        number of bytes in buf
 * @param oper_errno receives errno (plain path), EAGAIN when OpenSSL wants the
 *                   call retried, or the SSL error code otherwise.
 *                   NOTE(review): this definition takes oper_errno by value, so
 *                   the caller never sees what is stored — confirm whether the
 *                   declaration is meant to be `int &`.
 * @return the number of bytes written, 0, or a negative value on failure,
 *         exactly as returned by the underlying call.
 */
int socket::send(const char *buf, size_t len, int oper_errno)
{
#ifndef TUBEKIT_OPENSSL
    // write data to m_sockfd
    int result = ::send(m_sockfd, buf, len, 0);
    if (result == -1)
    {
        oper_errno = errno;
    }
    return result;
#else
    int bytes_written = SSL_write(m_ssl_instance, buf, static_cast<int>(len));
    if (bytes_written >= 0)
    {
        return bytes_written;
    }

    int ssl_error = SSL_get_error(m_ssl_instance, bytes_written);
    if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE)
    {
        // Not a failure: OpenSSL needs the socket to become ready again.
        oper_errno = EAGAIN;
    }
    else
    {
        oper_errno = ssl_error;
        LOG_ERROR("SSL_write error: %d", ssl_error);
    }
    return bytes_written;
#endif
}

/**
 * Switch m_sockfd to non-blocking mode by adding O_NONBLOCK to its file
 * status flags.
 *
 * @return true on success, false (after logging errno) if either fcntl call
 *         fails.
 */
bool socket::set_non_blocking()
{
    const int current_flags = fcntl(m_sockfd, F_GETFL, 0);
    if (current_flags < 0)
    {
        LOG_ERROR("socket::set_non_blocking(F_GETFL,O_NONBLOCK) errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }

    // keep every existing flag and add O_NONBLOCK
    if (fcntl(m_sockfd, F_SETFL, current_flags | O_NONBLOCK) < 0)
    {
        LOG_ERROR("socket::set_non_blocking(F_SETFL,O_NONBLOCK) errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }
    return true;
}

/**
 * Switch m_sockfd back to blocking mode by clearing O_NONBLOCK from its file
 * status flags.
 *
 * @return true on success, false (after logging errno) if either fcntl call
 *         fails.
 */
bool socket::set_blocking()
{
    const int current_flags = fcntl(m_sockfd, F_GETFL, 0);
    if (current_flags < 0)
    {
        LOG_ERROR("socket::set_blocking() errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }

    // keep every existing flag but drop O_NONBLOCK
    if (fcntl(m_sockfd, F_SETFL, current_flags & ~O_NONBLOCK) < 0)
    {
        LOG_ERROR("socket::set_blocking() errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }
    return true;
}

/**
 * Set the send buffer size (SO_SNDBUF) of m_sockfd.
 *
 * @param size requested buffer size in bytes
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_send_buffer(size_t size)
{
    // SO_SNDBUF takes an int option value, as set_recv_buffer already passes;
    // handing over a size_t with sizeof(size_t) only works by accident on
    // little-endian 64-bit targets.
    int buffer_size = static_cast<int>(size);
    if (setsockopt(m_sockfd, SOL_SOCKET, SO_SNDBUF, &buffer_size, sizeof(buffer_size)) < 0)
    {
        LOG_ERROR("socket set send buffer error: errno=%d errstr=%s", errno, strerror(errno));
        return false;
    }
    return true;
}

/**
 * Set the receive buffer size (SO_RCVBUF) of m_sockfd.
 *
 * @param size requested buffer size in bytes (converted to int for setsockopt)
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_recv_buffer(size_t size)
{
    int requested = size;
    const int rc = setsockopt(m_sockfd, SOL_SOCKET, SO_RCVBUF, &requested, sizeof(requested));
    if (rc >= 0)
    {
        return true;
    }
    LOG_ERROR("socket set recv buffer errno=%d errstr=%s", errno, strerror(errno));
    return false;
}

/**
 * Configure SO_LINGER on m_sockfd, which governs what close() does.
 *
 * The option has three combinations:
 *   - l_onoff == 0                  : l_linger is ignored
 *   - l_onoff != 0, l_linger == 0
 *   - l_onoff != 0, l_linger  > 0
 *
 * @param active  true sets l_onoff to 1, false sets it to 0
 * @param seconds value stored in l_linger
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_linger(bool active, size_t seconds)
{
    struct linger option;
    memset(&option, 0, sizeof(option));
    option.l_onoff = active ? 1 : 0;
    option.l_linger = seconds;

    const int rc = setsockopt(m_sockfd, SOL_SOCKET, SO_LINGER, &option, sizeof(option));
    if (rc >= 0)
    {
        return true;
    }
    LOG_ERROR("socket set linger error errno=%d errstr=%s", errno, strerror(errno));
    return false;
}

/**
 * Enable SO_KEEPALIVE on m_sockfd.
 *
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_keep_alive()
{
    const int enable = 1;
    const int rc = setsockopt(m_sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable));
    if (rc >= 0)
    {
        return true;
    }
    LOG_ERROR("socket set sock keep alive error: errno=%d errstr=%s", errno, strerror(errno));
    return false;
}

/**
 * Enable SO_REUSEADDR on m_sockfd.
 *
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_reuse_addr()
{
    const int enable = 1;
    const int rc = setsockopt(m_sockfd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
    if (rc >= 0)
    {
        return true;
    }
    LOG_ERROR("socket set sock reuser addr error: errno=%d errstr=%s", errno, strerror(errno));
    return false;
}

/**
 * Enable SO_REUSEPORT on m_sockfd.
 *
 * @return true on success, false (after logging errno) when setsockopt fails.
 */
bool socket::set_reuse_port()
{
    const int enable = 1;
    const int rc = setsockopt(m_sockfd, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(enable));
    if (rc >= 0)
    {
        return true;
    }
    LOG_ERROR("socket set sock reuser port error: errno=%d errstr=%s", errno, strerror(errno));
    return false;
}

/**
 * Create an IPv4 TCP socket.
 *
 * @return the new descriptor, or the negative value returned by ::socket
 *         (after logging errno) on failure. The descriptor is not stored in
 *         this object.
 */
int socket::create_tcp_socket()
{
    const int new_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (new_fd < 0)
    {
        LOG_ERROR("create tcp socket error: errno=%d errstr=%s", errno, strerror(errno));
    }
    return new_fd;
}

int socket::get_fd()
{
    return this->m_sockfd;
}