facebook / wangle

Wangle is a framework providing a set of common client/server abstractions for building services in a consistent, modular, and composable way.
Apache License 2.0
3.04k stars 535 forks source link

Acceptor #216

Open SteveSelva opened 1 year ago

SteveSelva commented 1 year ago

In wangle/server/Acceptor.cpp, the handshake for SSL connections is performed by Fizz. Because I need a callback on the ClientHello message to read the server name from its SNI extension, I implemented the handshake with folly::AsyncSSLSocket instead, which provides a ClientHello callback. Does this cause any problems with how the server functions?

Example.cpp (My Implementation)

/**
 * Entry point for every new (plaintext or TLS) transport accepted for the
 * HTTP server.
 *
 * If serverOptions_.newConnectionFilter is set, it is consulted first. When
 * the filter throws, the socket is reset and the connection is dropped.
 * Otherwise the transport is forwarded to HTTPSessionAcceptor::onNewConnection().
 */
void Example::onNewConnection(
    folly::AsyncTransport::UniquePtr sock,
    const SocketAddress* address,
    const std::string& nextProtocolName,
    SecureTransportType secureTransportType,
    const wangle::TransportInfo& tinfo) {
    LOG(INFO) << "Accepting Connection from Browser";

    // Returns false when the configured filter rejects (throws for) this
    // connection; in that case the socket has already been reset.
    const auto passesFilter = [&]() -> bool {
        auto& connectionFilter = serverOptions_.newConnectionFilter;
        if (!connectionFilter) {
            return true;  // No filter configured: accept unconditionally.
        }
        try {
            connectionFilter(sock.get(), address, nextProtocolName, secureTransportType, tinfo);
        } catch (const std::exception& e) {
            sock->closeWithReset();
            LOG(INFO) << "Exception filtering new socket: " << folly::exceptionStr(e);
            return false;
        }
        return true;
    };

    if (!passesFilter()) {
        return;
    }

    LOG(INFO) << "Security Protocol : " << sock->getSecurityProtocol();
    HTTPSessionAcceptor::onNewConnection(
        std::move(sock), address, nextProtocolName, secureTransportType, tinfo);
}

/**
 * Wraps a freshly accepted fd in a transport and hands it to the HTTP layer.
 *
 * Plaintext listeners, and SSL listeners that have no default SSL context,
 * get a folly::AsyncSocket passed to plaintextConnectionReady(). Otherwise
 * the fd is wrapped in a folly::AsyncSSLSocket so the SSLContext's ClientHello
 * callback can inspect the SNI server name. The server-side handshake is then
 * started, and the socket is passed to Example::onNewConnection().
 *
 * @param fd          accepted socket descriptor; the socket object created
 *                    here takes it over.
 * @param clientAddr  peer address of the connection.
 * @param acceptTime  when the connection was accepted; stored in tinfo.
 * @param tinfo       transport info for this connection, filled in here and
 *                    forwarded.
 *
 * NOTE(review): compared with wangle's Acceptor::processEstablishedConnection,
 * neither branch notifies observerList_. The TLS branch also does not count
 * the handshake against accConfig_.maxConcurrentSSLHandshakes, and it bypasses
 * Acceptor::connectionReady() (draining check, setMaxReadsPerEvent,
 * TransportInfo::initWithSocket, transformTransport).
 * Consider keeping wangle's handshake pipeline and reading the SNI name from
 * tinfo in onNewConnection() instead (presumably tinfo.sslServerName — TODO
 * confirm).
 */
void Example::processEstablishedConnection(int fd, const folly::SocketAddress& clientAddr, std::chrono::steady_clock::time_point acceptTime, wangle::TransportInfo& tinfo) noexcept {
    // Fetch the default context once and use it for both the callback
    // registration and the socket. As in wangle's Acceptor, a missing default
    // context means the connection is served as plaintext rather than
    // dereferencing a null context.
    std::shared_ptr<folly::SSLContext> sslCtx;
    if (isSSL) {
        sslCtx = this->getSSLContextManager()->getDefaultSSLCtx();
    }

    if (!sslCtx) {
        LOG(INFO) << "HTTP Connection";
        tinfo.secure = false;
        tinfo.acceptTime = acceptTime;
        folly::AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd, &clientAddr));
        tinfo.tfoSucceded = sock->getTFOSucceded();
        plaintextConnectionReady(std::move(sock), clientAddr, tinfo);
        return;
    }

    LOG(INFO) << "HTTPS Connection";

    {
        // folly's SSLContext::addClientHelloCallback() appends to the
        // context's callback list. Calling it once per connection, as before,
        // grew that list without bound and made every handshake run all the
        // accumulated copies. Register at most once per live SSLContext
        // instead. The registry is process-wide and mutex-guarded because
        // acceptor threads may share a context.
        // NOTE(review): assumes clientHelloCallback is the same for every
        // Example instance sharing a context — TODO confirm.
        static std::mutex hookedCtxMutex;
        static std::vector<std::weak_ptr<folly::SSLContext>> hookedCtxs;
        std::lock_guard<std::mutex> guard(hookedCtxMutex);

        bool alreadyHooked = false;
        for (auto it = hookedCtxs.begin(); it != hookedCtxs.end();) {
            std::shared_ptr<folly::SSLContext> live = it->lock();
            if (!live) {
                // The context was destroyed (e.g. after a reload); forget it.
                it = hookedCtxs.erase(it);
                continue;
            }
            if (live == sslCtx) {
                alreadyHooked = true;
            }
            ++it;
        }
        if (!alreadyHooked) {
            sslCtx->addClientHelloCallback(clientHelloCallback);
            hookedCtxs.push_back(sslCtx);
        }
    }

    // TLS counterpart of the plaintext branch's bookkeeping above.
    tinfo.secure = true;
    tinfo.acceptTime = acceptTime;

    folly::EventBase* evb = getEventBase();
    folly::NetworkSocket ns = folly::NetworkSocket::fromFd(fd);
    folly::AsyncSSLSocket::UniquePtr sslSock(new folly::AsyncSSLSocket(
        sslCtx, evb, ns, true /* server */, false /* deferSecurityNegotiation */, &clientAddr));
    tinfo.tfoSucceded = sslSock->getTFOSucceded();
    sslSock->setSupportedApplicationProtocols({ "h2", "http/1.1"});
    sslSock->enableClientHelloParsing();

    // NOTE(review): ownership of this HandshakeCB is not visible here. Unless
    // HandshakeCB deletes itself when the handshake succeeds or fails, one
    // instance leaks per connection.
    sslSock->sslAccept(new HandshakeCB(), std::chrono::minutes(1), folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);

    // NOTE(review): the session is created while the handshake is still in
    // flight, so the ALPN result is not known yet and an empty protocol name
    // is passed. HTTPSessionAcceptor presumably then falls back to its default
    // codec instead of honouring a negotiated "h2". Consider moving this call
    // into the handshake-success callback and passing
    // sslSock->getApplicationProtocol().
    Example::onNewConnection(std::move(sslSock), &clientAddr, {}, wangle::SecureTransportType::TLS, tinfo);
}

Acceptor.cpp (Wangle's Implementation)


void Acceptor::processEstablishedConnection(
    int fd,
    const SocketAddress& clientAddr,
    std::chrono::steady_clock::time_point acceptTime,
    TransportInfo& tinfo) noexcept {
  bool shouldDoSSL = false;
  if (accConfig_.isSSL()) {
    CHECK(sslCtxManager_);
    shouldDoSSL = sslCtxManager_->getDefaultSSLCtx() != nullptr;
  }
  if (shouldDoSSL) {
    AsyncSSLSocket::UniquePtr sslSock(makeNewAsyncSSLSocket(
        sslCtxManager_->getDefaultSSLCtx(), base_, fd, &clientAddr));
    ++numPendingSSLConns_;
    if (numPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
      VLOG(2) << "dropped SSL handshake on " << accConfig_.name
              << " too many handshakes in progress";
      auto error = SSLErrorEnum::DROPPED;
      auto latency = std::chrono::milliseconds(0);
      auto ex = folly::make_exception_wrapper<SSLException>(
          error, latency, sslSock->getRawBytesReceived());
      updateSSLStats(sslSock.get(), latency, error, ex);
      sslConnectionError(ex);
      return;
    }

    tinfo.tfoSucceded = sslSock->getTFOSucceded();
    for (const auto& cb : observerList_.getAll()) {
      cb->accept(sslSock.get());
    }
    startHandshakeManager(
        std::move(sslSock), this, clientAddr, acceptTime, tinfo);
  } else {
    tinfo.secure = false;
    tinfo.acceptTime = acceptTime;
    AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd, &clientAddr));
    tinfo.tfoSucceded = sock->getTFOSucceded();
    for (const auto& cb : observerList_.getAll()) {
      cb->accept(sock.get());
    }
    plaintextConnectionReady(std::move(sock), clientAddr, tinfo);
  }
}

void Acceptor::startHandshakeManager(
    AsyncSSLSocket::UniquePtr sslSock,
    Acceptor* /* unused: `this` is handed to the manager instead */,
    const SocketAddress& clientAddr,
    std::chrono::steady_clock::time_point acceptTime,
    TransportInfo& tinfo) noexcept {
  // Ask the security-protocol context manager for a handshake manager bound
  // to this connection, then give it the socket to drive the handshake.
  securityProtocolCtxManager_
      .getHandshakeManager(this, clientAddr, acceptTime, tinfo)
      ->start(std::move(sslSock));
}

/**
 * Final common step for every accepted connection (plaintext or TLS) before
 * it reaches onNewConnection().
 *
 * Drops the connection if the acceptor is draining or later. Otherwise it
 * tunes the underlying AsyncSocket, records transport info, notifies
 * observers, lets transformTransport() wrap the transport, and forwards the
 * result to onNewConnection().
 */
void Acceptor::connectionReady(
    AsyncTransport::UniquePtr sock,
    const SocketAddress& clientAddr,
    const string& nextProtocolName,
    SecureTransportType secureTransportType,
    TransportInfo& tinfo) {
  // Once draining has begun, new connections are dropped: returning releases
  // `sock` as it goes out of scope.
  if (state_ >= State::kDraining) {
    return;
  }

  // Limit the number of reads from the socket per poll loop iteration,
  // both to keep memory usage under control and to prevent one fast-
  // writing client from starving other connections.
  //
  // getUnderlyingTransport<AsyncSocket>() returns nullptr when no layer of
  // the transport stack is an AsyncSocket, so guard before dereferencing.
  auto asyncSocket = sock->getUnderlyingTransport<AsyncSocket>();
  if (asyncSocket) {
    asyncSocket->setMaxReadsPerEvent(accConfig_.socketMaxReadsPerEvent);
    tinfo.initWithSocket(asyncSocket);
  }
  tinfo.appProtocol = std::make_shared<std::string>(nextProtocolName);

  for (const auto& cb : observerList_.getAll()) {
    cb->ready(sock.get());
  }

  folly::AsyncTransport::UniquePtr transformed =
      transformTransport(std::move(sock));

  onNewConnection(
      std::move(transformed),
      &clientAddr,
      nextProtocolName,
      secureTransportType,
      tinfo);
}

void Acceptor::plaintextConnectionReady(
    AsyncSocket::UniquePtr sock,
    const SocketAddress& clientAddr,
    TransportInfo& tinfo) {
  // Plaintext connections have no negotiated application protocol and no
  // security layer; forward them to the common readiness path as such.
  const std::string noNegotiatedProtocol;
  connectionReady(
      std::move(sock),
      clientAddr,
      noNegotiatedProtocol,
      SecureTransportType::NONE,
      tinfo);
}

void Acceptor::sslConnectionReady(
    AsyncTransport::UniquePtr sock,
    const SocketAddress& clientAddr,
    const string& nextProtocol,
    SecureTransportType secureTransportType,
    TransportInfo& tinfo) {
  // The handshake for this connection is finished, so give back its slot in
  // the pending-handshake count. The count must be non-zero, because
  // processEstablishedConnection() counts each TLS connection on accept.
  CHECK(numPendingSSLConns_ > 0);
  numPendingSSLConns_--;

  connectionReady(
      std::move(sock), clientAddr, nextProtocol, secureTransportType, tinfo);

  // During shutdown, re-check whether draining can now finish with one fewer
  // handshake outstanding.
  const bool isDraining = (state_ == State::kDraining);
  if (isDraining) {
    checkDrained();
  }
}