pistacheio / pistache

A high-performance REST toolkit written in C++
https://pistacheio.github.io/pistache/
Apache License 2.0
3.18k stars 698 forks source link

[Question][Suggestion] How to change peer->getData / the Http::Handler onInput at runtime for a WebSockets implementation #1082

Open Fabio3rs opened 2 years ago

Fabio3rs commented 2 years ago

Hi. I was testing how to implement WebSockets using Pistache. The flow is to perform the handshake over HTTP/1.1 and then switch handling of the raw socket input data over to a WebSockets implementation. I tried a few things, such as inheriting from Http::Handler or from Rest::Private::RouterHandler. However, the onConnection and onInput methods are private, so the derived class cannot call the parent implementations from inside its own overrides.

A proxy can work: it receives an existing object created by Router::handler() and casts it to the base Tcp::Handler. However, the proxy must keep the state held by the Http::Handler (such as sizes) in sync, because the Transport class writes to the proxy object before calling onInput. That would look something like `*original = *this;`, treating `original` as an Http::Handler. This approach is not thread-safe, and adding a mutex may hurt performance.

I also considered changing the parser, or making onInput behave polymorphically, but I can't change the Parser of an already-connected Tcp::Peer.

I don't think I can implement WebSockets outside Pistache's codebase without changing the visibility of some methods in Pistache's code.

One of my tests uses inheritance and calls the parent implementation, but I had to change private to protected in Http::Handler for it to work.

Handler set

// Install the custom router handler on the HTTP endpoint so it receives the
// connection/input callbacks for every peer.
// NOTE(review): the handler class defined in this issue is `RouterHandler`;
// confirm `RouterHandlerProxy` is the intended type here.
endpoint->setHandler(
        std::make_shared<RouterHandlerProxy>(router));

Minimal handshake route code for studying the handler question:

/// Route handler for the WebSocket upgrade request.
///
/// Adds the handshake headers (Upgrade, Connection, Sec-WebSocket-Accept) to
/// the response and sends it with putOnWire(). It then attaches a
/// WebSocketHandler to the peer under the "__WEBSOCKETHANDLER" key.
/// RouterHandler::onInput looks up that key and forwards later input on this
/// connection to the WebSocket frame parser instead of the HTTP parser.
///
/// NOTE(review): getRaw("Sec-WebSocket-Key") is used without checking that
/// the client sent the header. Confirm how Pistache reports a missing header
/// and whether a 400 response should be sent instead.
/// NOTE(review): no status code is set here; presumably putOnWire() sends
/// "101 Switching Protocols" -- confirm.
void CWebSocketController::ws_route(const Pistache::Rest::Request &request,
                                    Pistache::Http::ResponseWriter response) {
    response.headers().clear();
    response.headers().addRaw(Header::Raw{"Upgrade", "websocket"});
    response.headers().addRaw(Header::Raw{"Connection", "Upgrade"});
    response.headers().addRaw(Header::Raw{
        "Sec-WebSocket-Accept",
        computeAccept(request.headers().getRaw("Sec-WebSocket-Key").value())});

    putOnWire(response);

    // Do not call fdatasync() on the peer's fd: sockets do not support it
    // (it fails with EINVAL), and it cannot flush pending socket output.
    auto peer = response.getPeer();

    auto wsHandler = std::make_shared<WebSocketHandler>();
    // Stored as a weak_ptr, so the handler does not own the connection.
    wsHandler->peer = peer;

    // websocket frame received callback (debug output, flushed once per frame)
    wsHandler->onMessage = [](const WebSocketHandler::frame &frame) {
        std::cout << frame.flags << '\n'
                  << frame.size << '\n'
                  << frame.payload << std::endl;
    };

    peer->putData("__WEBSOCKETHANDLER", wsHandler);
    threads.push(asyncws, wsHandler); // a thread to send data to the websocket
}
#pragma once
#include "stdafx.hpp"
#include <memory>
#include <pistache/router.h>

class RouterHandler : public Pistache::Rest::Private::RouterHandler {

  public:
    /*void onRequest(const Pistache::Http::Request &request,
                   Pistache::Http::ResponseWriter response) override;*/

    void
    onConnection(const std::shared_ptr<Pistache::Tcp::Peer> &peer) override;
    void
    onDisconnection(const std::shared_ptr<Pistache::Tcp::Peer> &peer) override;
    void onInput(const char *buffer, size_t len,
                 const std::shared_ptr<Pistache::Tcp::Peer> &peer) override;

    RouterHandler(Pistache::Rest::Router &router)
        : Pistache::Rest::Private::RouterHandler(router) {}

    std::shared_ptr<Pistache::Tcp::Handler> clone() const override {
        return std::make_shared<RouterHandler>(*this);
    }

    RouterHandler(const RouterHandler &) = default;
    RouterHandler(RouterHandler &&) = default;
    RouterHandler &operator=(const RouterHandler &) = default;
    RouterHandler &operator=(RouterHandler &&) = default;
    ~RouterHandler() override;
};
#include "RouterHandler.hpp"
#include "WebSocketHandler.hpp"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <pistache/peer.h>

// Logs the descriptor of the newly connected peer, then forwards to the base
// Pistache router handler's onConnection.
void RouterHandler::onConnection(
    const std::shared_ptr<Pistache::Tcp::Peer> &peer) {
    using HttpRouterHandler = Pistache::Rest::Private::RouterHandler;

    const auto fd = peer->fd();
    std::cout << __func__ << ": " << fd << std::endl;

    HttpRouterHandler::onConnection(peer);
}

// Called when a peer disconnects. If the peer had been upgraded to a
// WebSocket (a WebSocketHandler stored under "__WEBSOCKETHANDLER"), the
// handler's `disconnected` flag is raised. Any other thread that still holds
// the handler (e.g. a sender loop) can then observe that the connection is
// gone. The base router handler is always notified afterwards.
void RouterHandler::onDisconnection(
    const std::shared_ptr<Pistache::Tcp::Peer> &peer) {
    if (auto wsData = peer->tryGetData("__WEBSOCKETHANDLER")) {
        // The entry under this key is always a WebSocketHandler (see ws_route
        // and onInput), so the static cast matches what onInput does.
        std::static_pointer_cast<WebSocketHandler>(wsData)->disconnected = true;
        std::cout << "Websocket disconnected\n";
    }

    Pistache::Rest::Private::RouterHandler::onDisconnection(peer);
}

// Dispatches raw input for `peer`. Plain HTTP peers go to the base router
// handler. Peers carrying a "__WEBSOCKETHANDLER" entry (installed during the
// upgrade handshake) get their bytes delivered to that WebSocketHandler.
void RouterHandler::onInput(
    const char *buffer, size_t len,
    const std::shared_ptr<Pistache::Tcp::Peer> &peer) {
    auto wsData = peer->tryGetData("__WEBSOCKETHANDLER");
    if (!wsData) {
        Pistache::Rest::Private::RouterHandler::onInput(buffer, len, peer);
        return;
    }

    std::static_pointer_cast<WebSocketHandler>(wsData)->onInput(buffer, len,
                                                                peer);
}

// Out-of-line defaulted destructor for the class declared in RouterHandler.hpp.
RouterHandler::~RouterHandler() = default;

Incomplete WebSocket protocol reading:

#pragma once
#include "stdafx.hpp"
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>

/// Per-connection WebSocket state: an incremental frame parser plus the
/// callback that receives completed frames.
///
/// An instance is attached to a Pistache peer (see ws_route) and fed raw
/// socket bytes through onInput().
class WebSocketHandler {
  public:
    /// One WebSocket frame, together with the state needed to parse it
    /// incrementally across several input buffers.
    struct frame {
        // First header byte (presumably FIN/RSV/opcode per RFC 6455) and
        // second header byte (MASK bit + 7-bit length), as read off the wire.
        int32_t flags{}, lenByte{};
        // Payload length: initially the 7-bit length marker, replaced by the
        // decoded extended length once it has been read.
        uint64_t size{};
        // 4-byte masking key, read from the header when useMask is set.
        std::array<int8_t, 4> mask{};
        // Current parser state; 0 means "expecting the first header byte".
        int32_t readingState{};

        // Frame payload; resized to `size` before payload bytes are copied in.
        std::string payload{};
        // Write position within bufferTmp (header fields) or payload.
        size_t bufferPos{};
        // Scratch space for multi-byte header fields (extended length, mask).
        std::array<char, 16> bufferTmp{};

        // Address of the peer argument passed to WebSocketHandler::onInput,
        // set before onMessage is invoked. It points at the caller's
        // shared_ptr, so it is only meaningful while that call is running.
        const std::shared_ptr<Pistache::Tcp::Peer> *peer{};

        // MASK bit taken from the second header byte.
        bool useMask{};

        // Consumes bytes from `buffer`; returns {bytes consumed from buffer,
        // whether a frame was completed}.
        std::pair<size_t, bool> receiveData(const char *buffer, size_t lenraw);
    };

    // Atomic flag, presumably signalling to other threads (e.g. the sender
    // thread) that the peer has gone away.
    // NOTE(review): confirm which code sets and reads it.
    std::atomic<bool> disconnected{false};

  protected:
    // Parser state for the frame currently being received.
    frame frameInst;

  public:
    // Invoked with the parsed frame when a frame is completed.
    std::function<void(const frame&)> onMessage;
    // Non-owning reference to the connection this handler serves.
    std::weak_ptr<Pistache::Tcp::Peer> peer;

    // Feeds received bytes to frameInst and calls onMessage when a frame
    // completes.
    void onInput(const char *buffer, size_t len,
                 const std::shared_ptr<Pistache::Tcp::Peer> &peer);
};
#include "WebSocketHandler.hpp"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <pistache/peer.h>
#include <tuple>

std::pair<size_t, bool> WebSocketHandler::frame::receiveData(const char *buffer,
                                                             size_t lenraw) {
    const uint8_t *bytesRaw = reinterpret_cast<const uint8_t *>(buffer);
    size_t inBufferOffset = 0;

    bool done = false;

    while (inBufferOffset < lenraw) {
        switch (readingState) {
        case 0:
            flags = bytesRaw[inBufferOffset++];
            ++readingState;
            break;

        case 1:
            lenByte = bytesRaw[inBufferOffset++];
            useMask = (lenByte & 0x80) != 0;
            size = lenByte & 0x7F;
            ++readingState;
            bufferPos = 0;
            break;

        case 2:
            if (size == 126 || size == 127) {
                bufferTmp[bufferPos++] = buffer[inBufferOffset++];

                if (bufferPos == 8 && size == 127) {
                    size = be64toh(
                        *reinterpret_cast<uint64_t *>(bufferTmp.data()));
                    ++readingState;
                    bufferPos = 0;
                } else if (bufferPos == 2 && size == 126) {
                    size = be16toh(
                        *reinterpret_cast<uint16_t *>(bufferTmp.data()));
                    ++readingState;
                    bufferPos = 0;
                }
            } else {
                ++readingState;
                bufferPos = 0;
            }
            break;

        case 3:
            if (useMask) {
                bufferTmp[bufferPos++] = buffer[inBufferOffset++];

                if (bufferPos == 4) {
                    std::copy_n(bufferTmp.begin(), 4, mask.begin());
                    ++readingState;
                    bufferPos = 0;
                }
            } else {
                ++readingState;
            }
            break;
        case 4:
            payload.resize(size);
            ++readingState;
            bufferPos = 0;
            break;

        default:
            payload[bufferPos++] = buffer[inBufferOffset++];

            if (bufferPos == size) {
                done = true;
            }
            break;
        }
    }

    if (!done) {
        return {inBufferOffset, false};
    }

    readingState = 0;
    bufferPos = 0;

    for (size_t i = 0; i < payload.size(); i++) {
        payload[i] ^= mask[i % 4];
    }

    return {inBufferOffset, true};
}

void WebSocketHandler::onInput(
    const char *buffer, size_t lenraw,
    const std::shared_ptr<Pistache::Tcp::Peer> &inpeer) {

    size_t inBufferOffset = 0;

    do {
        bool done = false;
        const char *bufferit = buffer + inBufferOffset;
        size_t currentLen = lenraw - inBufferOffset;

        std::tie(inBufferOffset, done) =
            frameInst.receiveData(bufferit, currentLen);

        if (!done) {
            return;
        }

        frameInst.peer = &inpeer;

        if (onMessage) {
            onMessage(frameInst);
        }

        frameInst.flags = 0;
        frameInst.lenByte = 0;
    } while (inBufferOffset < lenraw);
}