tokio-rs / axum

Ergonomic and modular web framework built with Tokio, Tower, and Hyper

Update gRPC multiplex example for axum 0.7 and hyper 1 #2736

Open FlyinPancake opened 4 months ago

FlyinPancake commented 4 months ago

Feature Request

Update example rest-grpc-multiplex

Motivation

Currently, the example code in https://github.com/tokio-rs/axum/tree/main/examples/rest-grpc-multiplex is commented out and marked as TODO.

The code in the example is not really straightforward for beginners, and a working version can be a good starting point for learning exercises.

Proposal

Update the code to be compatible with axum 0.7 and hyper 1.

takkuumi commented 4 months ago

I'll try to make it work this week.

FlyinPancake commented 3 months ago

Has there been any progress on this?

jplatte commented 3 months ago

There has been progress on https://github.com/hyperium/tonic/pull/1670, and once that is merged and released, this should be pretty simple to do.

abhiaagarwal commented 2 months ago

The tonic PR has been merged; if no one is working on updating the example, I'll give it a shot!

mladedav commented 2 months ago

I think you can go ahead, just note that while the PR has been merged, a new version is not yet released and there might be a few more breaking changes.

But feel free to go ahead, I wouldn't expect any other large changes that should influence the example, just so you know.

abhiaagarwal commented 2 months ago

@mladedav I'll probably just open a draft PR with a git dependency, then.

abhiaagarwal commented 2 months ago

I've been working on this on and off and haven't been able to get this working, but I'll share a few of my findings here.

My approach was to use tower::steer to drive the routing between the axum/grpc services. However, due to the removal of the common types in Hyper 1.0, axum and tonic both have different request and response types. I believe the "right" way to do this would be to have a middleware that wraps the tonic service to return axum types since tower::steer's type signature needs to match axum's. Boxing the service doesn't work.

I've still been hacking away on it, but it's been challenging, to say the least. If anyone wants to give it a shot, feel free.
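
The type mismatch described above can be illustrated without either framework. The following is a minimal std-only sketch, not the real axum or tonic API: `WebBody`, `GrpcBody`, `Handler`, `MapBody`, `Steer`, and `pick_by_path` are all hypothetical stand-ins. It assumes, as the comment suggests, that a steer-style dispatcher can only hold services that share one response type, so the gRPC side is wrapped in an adapter that converts its body first.

```rust
// Hypothetical stand-ins for the two frameworks' response body types.
#[derive(Debug, PartialEq)]
struct WebBody(String);

#[derive(Debug, PartialEq)]
struct GrpcBody(Vec<u8>);

/// Minimal handler abstraction (a simplified stand-in for `tower::Service`).
trait Handler {
    type Body;
    fn handle(&mut self, path: &str) -> Self::Body;
}

struct Web;
impl Handler for Web {
    type Body = WebBody;
    fn handle(&mut self, path: &str) -> WebBody {
        WebBody(format!("web:{path}"))
    }
}

struct Grpc;
impl Handler for Grpc {
    type Body = GrpcBody;
    fn handle(&mut self, path: &str) -> GrpcBody {
        GrpcBody(path.as_bytes().to_vec())
    }
}

/// Adapter that converts the inner handler's body into another type.
struct MapBody<H, F> {
    inner: H,
    f: F,
}

impl<H, F, B> Handler for MapBody<H, F>
where
    H: Handler,
    F: FnMut(H::Body) -> B,
{
    type Body = B;
    fn handle(&mut self, path: &str) -> B {
        (self.f)(self.inner.handle(path))
    }
}

type BoxHandler<B> = Box<dyn Handler<Body = B>>;

/// Steer-like dispatcher: every handler must produce the same body type `B`.
struct Steer<B> {
    handlers: Vec<BoxHandler<B>>,
    pick: fn(&str) -> usize,
}

impl<B> Steer<B> {
    fn handle(&mut self, path: &str) -> B {
        let index = (self.pick)(path);
        self.handlers[index].handle(path)
    }
}

/// Hypothetical routing rule: paths starting with "/grpc." go to handler 1.
fn pick_by_path(path: &str) -> usize {
    if path.starts_with("/grpc.") { 1 } else { 0 }
}

fn build() -> Steer<WebBody> {
    // Without this adapter the two handlers could not share one Vec.
    let grpc_as_web = MapBody {
        inner: Grpc,
        f: |body: GrpcBody| WebBody(String::from_utf8_lossy(&body.0).into_owned()),
    };
    let handlers: Vec<BoxHandler<WebBody>> = vec![Box::new(Web), Box::new(grpc_as_web)];
    Steer { handlers, pick: pick_by_path }
}
```

Calling `build().handle("/test")` goes to the web handler, while a path such as `/grpc.Echo/Say` goes through the adapter. In the real crates the conversion would have to happen on the response body type rather than on a plain struct, which is where the difficulty reported here comes from.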

rtkay123 commented 2 months ago

Heads up, a new version of tonic, 0.12.0, has been released: https://github.com/hyperium/tonic/pull/1740

brocaar commented 1 month ago

I have not been able to solve my use case with the PR from @abhiaagarwal. It seems that when using into_router(), the layer configuration is lost if layer is called before the into_router method, while calling it afterwards results in type errors (in my use case at least).


I have also been struggling with this for more time than I hoped, but the code below seems to work. The approach is different: it adds the axum Router service to the gRPC server as a layer, using a multiplexer layer. Requests that are not gRPC are handled by the given service (the layer returns early). Snippet:

    let web = Router::new()
        .route("/test", get(|| async { "Hello, World!" }))
        .into_service()
        .map_response(|r| r.map(tonic::body::boxed));

    let grpc = Server::builder()
        .layer(GrpcMultiplexLayer::new(web))
        .add_service(
            ...
        );

    let addr = "[::1]:50051".parse().unwrap();
    grpc.serve(addr).await.unwrap();
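
The dispatch rule described above (anything that is not gRPC goes to the wrapped service) comes down to a prefix check on the Content-Type header. A minimal std-only sketch of that decision, where the `Route` enum and the `classify` function are hypothetical names used only for illustration:

```rust
/// Which side of the multiplexer should handle a request (hypothetical type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Route {
    Grpc,
    Other,
}

/// Classify a request by the raw bytes of its Content-Type header, if present.
fn classify(content_type: Option<&[u8]>) -> Route {
    match content_type {
        // A prefix match also accepts variants such as "application/grpc+proto".
        Some(value) if value.starts_with(b"application/grpc") => Route::Grpc,
        // A missing header or any other media type goes to the non-gRPC service.
        _ => Route::Other,
    }
}
```

Note that a plain prefix match also classifies `application/grpc-web` as gRPC; whether that is desirable depends on the services behind the multiplexer.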

Full example:

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use axum::{routing::get, Router};
use futures::ready;
use http::{header::CONTENT_TYPE, Request, Response};
use http_body::Body;
use pin_project::pin_project;
use tonic::transport::Server;
use tonic_reflection::server::Builder as TonicReflectionBuilder;
use tower::{Layer, Service, ServiceExt};

type BoxError = Box<dyn std::error::Error + Send + Sync>;

#[pin_project(project = GrpcMultiplexFutureEnumProj)]
enum GrpcMultiplexFutureEnum<FS, FO> {
    Grpc {
        #[pin]
        future: FS,
    },
    Other {
        #[pin]
        future: FO,
    },
}

#[pin_project]
pub struct GrpcMultiplexFuture<FS, FO> {
    #[pin]
    future: GrpcMultiplexFutureEnum<FS, FO>,
}

impl<ResBody, FS, FO, ES, EO> Future for GrpcMultiplexFuture<FS, FO>
where
    ResBody: Body,
    FS: Future<Output = Result<Response<ResBody>, ES>>,
    FO: Future<Output = Result<Response<ResBody>, EO>>,
    ES: Into<BoxError> + Send,
    EO: Into<BoxError> + Send,
{
    type Output = Result<Response<ResBody>, Box<dyn std::error::Error + Send + Sync + 'static>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.future.project() {
            GrpcMultiplexFutureEnumProj::Grpc { future } => future.poll(cx).map_err(Into::into),
            GrpcMultiplexFutureEnumProj::Other { future } => future.poll(cx).map_err(Into::into),
        }
    }
}

#[derive(Debug, Clone)]
pub struct GrpcMultiplexService<S, O> {
    grpc: S,
    other: O,
    grpc_ready: bool,
    other_ready: bool,
}

impl<ReqBody, ResBody, S, O> Service<Request<ReqBody>> for GrpcMultiplexService<S, O>
where
    ResBody: Body,
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    O: Service<Request<ReqBody>, Response = Response<ResBody>>,
    S::Error: Into<BoxError> + Send,
    O::Error: Into<BoxError> + Send,
{
    type Response = S::Response;
    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
    type Future = GrpcMultiplexFuture<S::Future, O::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        loop {
            match (self.grpc_ready, self.other_ready) {
                (true, true) => {
                    return Ok(()).into();
                }
                (false, _) => {
                    ready!(self.grpc.poll_ready(cx)).map_err(Into::into)?;
                    self.grpc_ready = true;
                }
                (_, false) => {
                    ready!(self.other.poll_ready(cx)).map_err(Into::into)?;
                    self.other_ready = true;
                }
            }
        }
    }

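    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // Callers are expected to drive poll_ready to readiness before every
        // call, so panicking here when they did not is acceptable.
        assert!(self.grpc_ready);
        assert!(self.other_ready);

        if is_grpc_request(&request) {
            // The call consumes readiness; poll the gRPC service again next time.
            self.grpc_ready = false;
            GrpcMultiplexFuture {
                future: GrpcMultiplexFutureEnum::Grpc {
                    future: self.grpc.call(request),
                },
            }
        } else {
            // The call consumes readiness; poll the other service again next time.
            self.other_ready = false;
            GrpcMultiplexFuture {
                future: GrpcMultiplexFutureEnum::Other {
                    future: self.other.call(request),
                },
            }
        }
    }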
}

#[derive(Debug, Clone)]
pub struct GrpcMultiplexLayer<O> {
    other: O,
}

impl<O> GrpcMultiplexLayer<O> {
    pub fn new(other: O) -> Self {
        Self { other }
    }
}

impl<S, O> Layer<S> for GrpcMultiplexLayer<O>
where
    O: Clone,
{
    type Service = GrpcMultiplexService<S, O>;

    fn layer(&self, grpc: S) -> Self::Service {
        GrpcMultiplexService {
            grpc,
            other: self.other.clone(),
            grpc_ready: false,
            other_ready: false,
        }
    }
}

fn is_grpc_request<B>(req: &Request<B>) -> bool {
    req.headers()
        .get(CONTENT_TYPE)
        .map(|content_type| content_type.as_bytes())
        .filter(|content_type| content_type.starts_with(b"application/grpc"))
        .is_some()
}

#[tokio::main]
async fn main() {
    let web = Router::new()
        .route("/test", get(|| async { "Hello, World!" }))
        .into_service()
        .map_response(|r| r.map(tonic::body::boxed));

    let grpc = Server::builder()
        .layer(GrpcMultiplexLayer::new(web))
        .add_service(
            ...
        );

    let addr = "[::1]:50051".parse().unwrap();

    grpc.serve(addr).await.unwrap();
}
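
The readiness bookkeeping in `poll_ready` and `call` above is easier to follow in isolation. Below is a simplified std-only model, under stated assumptions: `Countdown` and `Both` are hypothetical types, there is no `Context` or waker, and errors are left out. It shows the flags remembering which side already reported ready so that side is not polled again. Unlike the example above, this sketch also clears the flag of the side that was called, on the assumption that each inner service should be polled for readiness again before its next call.

```rust
use std::task::Poll;

/// Hypothetical inner service that becomes ready after a fixed number of polls.
struct Countdown {
    polls_left: u32,
    polled: u32,
}

impl Countdown {
    fn new(polls_left: u32) -> Self {
        Self { polls_left, polled: 0 }
    }

    fn poll_ready(&mut self) -> Poll<()> {
        self.polled += 1;
        if self.polls_left == 0 {
            Poll::Ready(())
        } else {
            self.polls_left -= 1;
            Poll::Pending
        }
    }
}

/// Wrapper that is ready only once both inner services are ready.
struct Both {
    a: Countdown,
    b: Countdown,
    a_ready: bool,
    b_ready: bool,
}

impl Both {
    fn new(a_polls: u32, b_polls: u32) -> Self {
        Self {
            a: Countdown::new(a_polls),
            b: Countdown::new(b_polls),
            a_ready: false,
            b_ready: false,
        }
    }

    fn poll_ready(&mut self) -> Poll<()> {
        loop {
            match (self.a_ready, self.b_ready) {
                (true, true) => return Poll::Ready(()),
                (false, _) => match self.a.poll_ready() {
                    Poll::Ready(()) => self.a_ready = true,
                    Poll::Pending => return Poll::Pending,
                },
                (_, false) => match self.b.poll_ready() {
                    Poll::Ready(()) => self.b_ready = true,
                    Poll::Pending => return Poll::Pending,
                },
            }
        }
    }

    /// Dispatch to one side; only that side's readiness is consumed.
    fn call(&mut self, to_a: bool) {
        assert!(self.a_ready && self.b_ready, "poll_ready was not driven to Ready");
        if to_a {
            self.a_ready = false;
        } else {
            self.b_ready = false;
        }
    }
}
```

With `Both::new(1, 2)`, the wrapper returns `Pending` three times before returning `Ready`, and the side that became ready first is not polled again while the other one is still pending.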

Requirements:

  tokio = { version = "1.38", features = ["macros", "rt-multi-thread"] }
  tonic = "0.12"
  tower = { version = "0.4", features = ["steer"] }
  tonic-reflection = "0.12"
  axum = "0.7"
  futures = "0.3"
  hyper = "1.4"
  hyper-util = "0.1"
  http-body-util = "0.1"
  anyhow = "1.0"
  http = "1.1"
  pin-project = "1.1"
  http-body = "1.0"

infiniteregrets commented 1 month ago

@brocaar thanks so much for providing your implementation, it works just fine! I was trying to do something similar and finding your code saved me hours!

spence commented 3 weeks ago

thanks @brocaar

i also needed the remote IP, but i wasn't able to get into_make_service_with_connect_info working. however, tonic sets the remote IP as a request extension, and i made an axum extractor to simplify it:

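use std::net::SocketAddr;

use axum::{
  extract::{FromRequest, Request},
  Json,
};
use serde_json::{json, Value};
use tonic::transport::server::TcpConnectInfo;

// .route("/ip", get(get_ip))
async fn get_ip(TonicRemoteAddr(addr): TonicRemoteAddr) -> Json<Value> {
  Json(json!({ "ip": addr.ip() }))
}

#[derive(Debug, Clone, Copy)]
pub struct TonicRemoteAddr(pub SocketAddr);

#[axum::async_trait]
impl<S: Send + Sync> FromRequest<S> for TonicRemoteAddr {
  type Rejection = std::convert::Infallible;

  async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    // Look up the connection info that tonic stored in the request extensions.
    let result = req
      .extensions()
      .get::<TcpConnectInfo>()
      .and_then(|connect_info| connect_info.remote_addr())
      .map(TonicRemoteAddr);

    // Panics if the extension or the remote address is missing.
    Ok(result.expect("TcpConnectInfo with a remote address in request extensions"))
  }
}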
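
The extractor above relies on a value being stored in the request extensions and looked up by its type. A simplified std-only model of that lookup is sketched below. It is not the real `http::Extensions` or tonic API: `Extensions`, `ConnectInfo`, `MissingRemoteAddr`, and `remote_addr` are hypothetical names. It also returns an error instead of panicking when the value is absent, which is one possible design choice for the missing-address case.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::net::SocketAddr;

/// Simplified stand-in for a request-extension map: at most one value per type.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    fn insert<T: Any>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    fn get<T: Any>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

/// Hypothetical connection info, playing the role of a server-provided extension.
struct ConnectInfo {
    remote_addr: Option<SocketAddr>,
}

/// Error returned when no remote address is available (hypothetical type).
#[derive(Debug, PartialEq)]
struct MissingRemoteAddr;

/// Look up the remote address without panicking when it is absent.
fn remote_addr(extensions: &Extensions) -> Result<SocketAddr, MissingRemoteAddr> {
    extensions
        .get::<ConnectInfo>()
        .and_then(|info| info.remote_addr)
        .ok_or(MissingRemoteAddr)
}
```

In an axum extractor, the same idea would mean using a real rejection type instead of `Infallible`, so that a request without connection info produces an error response rather than a panic.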