grpc / grpc-java

The Java gRPC implementation. HTTP/2 based RPC
https://grpc.io/docs/languages/java/
Apache License 2.0

Scala: Context propagation and async interceptors #2984

Open ostronom opened 7 years ago

ostronom commented 7 years ago

I want to pass some values from an interceptor to the RPC handler. I've read that this can be done with Contexts. The problem is that my interceptor is asynchronous, i.e. it "waits" for a Future to resolve before calling the next listener, and the Context is lost in this situation. My code is in Scala:

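import akka.actor.ActorSystem
import cats.data.OptionT
import cats.instances.future._ // Monad[Future] for OptionT (needed explicitly before cats 2.2)
import io.grpc._
import scala.concurrent.Future
// AnyLogging (which provides `log`) is a project-specific trait not shown in this post.

case class AsyncContextawareInterceptor[A](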
    f: Metadata ⇒ Future[Either[Status, (Context.Key[A], A)]]
)(implicit val system: ActorSystem)
    extends ServerInterceptor
    with AnyLogging {
  import system.dispatcher

  sealed trait Msg
  case object HalfClose extends Msg
  case object Cancel extends Msg
  case object Complete extends Msg
  case object Ready extends Msg
  case class Message[T](msg: T) extends Msg

  override def interceptCall[ReqT, RespT](call: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] =
    new ServerCall.Listener[ReqT] {
      private val stash = new java.util.concurrent.ConcurrentLinkedQueue[Msg]()
      private var interceptor: Option[ServerCall.Listener[ReqT]] = None

      private def enqueueAndProcess(msg: Msg) =
        if (interceptor.isDefined) processMessage(msg) else stash.add(msg)

      private def processMessage(msg: Msg) = msg match {
        case HalfClose ⇒ interceptor.foreach(_.onHalfClose)
        case Cancel ⇒ interceptor.foreach(_.onCancel)
        case Complete ⇒ interceptor.foreach(_.onComplete)
        case Ready ⇒ interceptor.foreach(_.onReady)
        case Message(msg: ReqT @unchecked) ⇒ interceptor.foreach(_.onMessage(msg))
      }

      private def processMessages() = while (!stash.isEmpty) {
        Option(stash.poll).foreach(processMessage)
      }

      override def onHalfClose(): Unit = enqueueAndProcess(HalfClose)

      override def onCancel(): Unit = enqueueAndProcess(Cancel)

      override def onComplete(): Unit = enqueueAndProcess(Complete)

      override def onReady(): Unit = enqueueAndProcess(Ready)

      override def onMessage(message: ReqT): Unit = enqueueAndProcess(Message(message))

      f(headers).map {
        case Right((k, v)) ⇒
          val context = Context.current.withValue(k, v)
          interceptor = Some(Contexts.interceptCall(context, call, headers, next))
          processMessages()
        case Left(status) ⇒ call.close(status, new Metadata())
      }.recover {
        case t: Throwable ⇒
          log.error(t, "AsyncContextawareInterceptor future failed")
          call.close(Status.fromThrowable(t), new Metadata())
      }
    }
}

object AuthInterceptor {
  val BOTID_CONTEXT_KEY: Context.Key[Int] = Context.key[Int]("botId")
  val TOKEN_HEADER_KEY: Metadata.Key[String] = Metadata.Key.of[String]("token", Metadata.ASCII_STRING_MARSHALLER)

  def authInterceptor(resolver: String ⇒ Future[Option[Int]])(implicit system: ActorSystem): ServerInterceptor =
    AsyncContextawareInterceptor { metadata ⇒
      import system.dispatcher
      (for {
        token ← OptionT.fromOption[Future](Option(metadata.get(TOKEN_HEADER_KEY)))
        botId ← OptionT(resolver(token))
      } yield botId).value.map {
        case Some(id) ⇒ Right(BOTID_CONTEXT_KEY → id)
        case None ⇒ Left(Status.PERMISSION_DENIED)
      }
    }
}

The problem is that BOTID_CONTEXT_KEY.get is null in the RPC handler, even after the Future has resolved and a non-null value has been set.

ejona86 commented 7 years ago

I saw the question at http://stackoverflow.com/q/43805316/4690866 , but I don't think we have enough Scala experience to answer. Context uses ThreadLocal to propagate state. IIRC Scala has some feature that can dispatch work to other threads, which could break the ThreadLocal. You either need to teach Scala how to copy the ThreadLocal or maybe use Scala's equivalent of InheritableThreadLocal that auto-propagates to other threads.
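
To make the ThreadLocal point concrete, here is a minimal sketch using grpc-java's io.grpc.Context directly. The key name and the single-thread executor used to run it are illustrative assumptions, not taken from the thread. A value attached on the calling thread is not visible inside a Future body that runs on a pool thread, while a Context captured before the thread hop and re-entered with Context.call is:

```scala
import io.grpc.Context
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object ContextHopDemo {
  // Hypothetical key, used only for this illustration.
  val UserKey: Context.Key[String] = Context.key[String]("user")

  /** Returns (value seen by a naive Future, value seen with an explicitly captured Context). */
  def run()(implicit ec: ExecutionContext): (Option[String], Option[String]) = {
    val ctx = Context.current().withValue(UserKey, "alice")
    val previous = ctx.attach() // makes ctx current on *this* thread only
    try {
      // Naive: the body reads Context.current() on whatever thread the executor picks.
      val naive = Future(Option(UserKey.get()))
      // Explicit: capture the Context here and re-enter it inside the Future body.
      val captured = Context.current()
      val explicit = Future(captured.call[Option[String]](() => Option(UserKey.get())))
      (Await.result(naive, 5.seconds), Await.result(explicit, 5.seconds))
    } finally ctx.detach(previous) // restore whatever was current before
  }
}
```

If the executor happened to run the task on the calling thread, the naive read would also see the value; the difference only shows up once the work really moves to another thread.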

Note that InheritableThreadLocal itself isn't a great solution for Java, but the equivalent in Scala may work better. If you reference some docs I could maybe help determine if a solution would work well.

ostronom commented 7 years ago

Thank you for your reply. I will investigate the ThreadLocal solution. For now, I've made this work with two interceptors: the first is the one in the opening post, which propagates the value to the next listener via headers; the second is synchronous and moves the value from the headers into the Context.
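
The second, synchronous interceptor of this workaround might look roughly like the sketch below. It assumes the asynchronous interceptor writes the resolved bot id into a request header (the name x-bot-id is hypothetical) and that it always overwrites or removes any client-supplied value for that header; otherwise a client could set the header itself and bypass authentication. Because Contexts.interceptCall runs on the gRPC thread, the handler's startCall and every listener callback see the new Context.

```scala
import io.grpc._
import scala.util.Try

// Sketch of the second, synchronous interceptor. The header name and keys are
// hypothetical; the post does not say which header the async interceptor writes.
object HeaderToContextInterceptor extends ServerInterceptor {
  val BotIdHeader: Metadata.Key[String] =
    Metadata.Key.of("x-bot-id", Metadata.ASCII_STRING_MARSHALLER)
  val BotIdKey: Context.Key[Integer] = Context.key[Integer]("botId")

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]
  ): ServerCall.Listener[ReqT] =
    Option(headers.get(BotIdHeader)).flatMap(s => Try(s.toInt).toOption) match {
      case Some(id) =>
        // Synchronous: the Context is attached around startCall and each callback.
        val ctx = Context.current().withValue(BotIdKey, Integer.valueOf(id))
        Contexts.interceptCall(ctx, call, headers, next)
      case None =>
        // No (valid) value from the first interceptor: reject and return a no-op listener.
        call.close(Status.UNAUTHENTICATED.withDescription("missing bot id"), new Metadata())
        new ServerCall.Listener[ReqT] {}
    }
}
```

Registration order matters with this design: the asynchronous interceptor has to run first so that the header is already populated when this one executes.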

I want to try implementing an explicit context that travels with the request as one of the listener/handler parameters. Maybe it isn't too hard, and it would solve this kind of problem.
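
A minimal sketch of the explicit-context idea, under the assumption that a handler is split into plain business logic and a thin adapter: the adapter reads the gRPC Context once, synchronously, before any Future is created, and everything after that receives a CallContext value as an ordinary parameter, so thread hops no longer matter. CallContext, greet and greetFromGrpc are hypothetical names for illustration.

```scala
import io.grpc.Context
import scala.concurrent.{ExecutionContext, Future}

// Hypothetical explicit context type; the comment above only describes the idea.
final case class CallContext(botId: Int)

object ExplicitContextSketch {
  // Hypothetical key; Integer (not Int) so that a missing value shows up as null.
  val BotIdKey: Context.Key[Integer] = Context.key[Integer]("botId")

  // Business logic: the context is a normal parameter, safe to use on any thread.
  def greet(ctx: CallContext, name: String)(implicit ec: ExecutionContext): Future[String] =
    Future(s"bot ${ctx.botId} greets $name")

  // Edge adapter: read the gRPC Context once, on the calling thread, then pass it along.
  def greetFromGrpc(name: String)(implicit ec: ExecutionContext): Future[String] =
    Option(BotIdKey.get()) match {
      case Some(id) => greet(CallContext(id.intValue), name)
      case None     => Future.failed(new IllegalStateException("botId missing from Context"))
    }
}
```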

anhldbk commented 7 years ago

@ejona86 So how can I intercept gRPC calls and validate their header tokens asynchronously?

beatkyo commented 6 years ago

You can capture the Context in your Future. Since Context is immutable, it is safe to pass around.

In the code below, prevCtx is captured on the interceptor thread and made available on the Future's thread.

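import io.grpc._
import io.grpc.Context.key
import io.grpc.Metadata.Key.of
import io.grpc.ServerCall.Listener
import scala.concurrent.{ExecutionContext, Future}

abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {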

  protected val delegate: Future[Listener[Q]]

  private val eventually = delegate.foreach _

  override def onComplete(): Unit = eventually { _.onComplete() }
  override def onCancel(): Unit = eventually { _.onCancel() }
  override def onMessage(message: Q): Unit = eventually { _ onMessage message }
  override def onHalfClose(): Unit = eventually { _.onHalfClose() }
  override def onReady(): Unit = eventually { _.onReady() }

}

object Keys {
  val AUTH_META_KEY: Metadata.Key[String] = of("auth-key", Metadata.ASCII_STRING_MARSHALLER)

  val ACCOUNT_CTX_KEY: Context.Key[String] = key("account")
}

import Keys._

class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {

  override def interceptCall[Q, R](
      call: ServerCall[Q, R],
      headers: Metadata,
      next: ServerCallHandler[Q, R]
  ): Listener[Q] = {

    val prevCtx = Context.current
    val id = headers get AUTH_META_KEY take 6

    new FutureListener[Q] {
      protected val delegate = Future {
        val nextCtx = prevCtx withValue (ACCOUNT_CTX_KEY, s"user-$id")

        Contexts.interceptCall(nextCtx, call, headers, next)
      }
    }

  }

}

ostronom commented 6 years ago

This solution actually does nothing but harm: it causes most requests to be cancelled with a "half-closed" exception.
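
One plausible cause of these failures: eventually schedules every callback as an independent Future.foreach task, so on a multi-threaded ExecutionContext onHalfClose can run before (or concurrently with) onMessage. grpc-java's unary server handler then closes the call because it was half-closed without a request. Below is a minimal sketch that keeps the captured-Context idea but preserves callback order by chaining each callback onto the previous one. OrderedFutureListener is a hypothetical name, and the sketch assumes the interceptor itself handles a failed delegate, for example by closing the call.

```scala
import io.grpc.ServerCall
import scala.concurrent.{ExecutionContext, Future}

// Hypothetical helper: forwards callbacks to a listener that only becomes available
// once `delegate` completes, running them strictly in the order they arrived.
final class OrderedFutureListener[Q](delegate: Future[ServerCall.Listener[Q]])(
    implicit ec: ExecutionContext
) extends ServerCall.Listener[Q] {

  // Tail of the chain; every new callback is appended after the previous one.
  private[this] var last: Future[ServerCall.Listener[Q]] = delegate

  // `synchronized` guards `last`. Each map stage starts only after the previous stage
  // has finished, so callbacks never overlap or reorder. If a callback throws, `last`
  // becomes a failed Future and later callbacks are skipped.
  private def enqueue(f: ServerCall.Listener[Q] => Unit): Unit = synchronized {
    last = last.map { l => f(l); l }
  }

  override def onMessage(message: Q): Unit = enqueue(_.onMessage(message))
  override def onHalfClose(): Unit = enqueue(_.onHalfClose())
  override def onCancel(): Unit = enqueue(_.onCancel())
  override def onComplete(): Unit = enqueue(_.onComplete())
  override def onReady(): Unit = enqueue(_.onReady())
}
```

In the interceptor above, the anonymous FutureListener would then become new OrderedFutureListener[Q](Future { Contexts.interceptCall(nextCtx, call, headers, next) }), with prevCtx still captured on the interceptor thread as before.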