grpc / grpc-kotlin

Kotlin gRPC implementation. HTTP/2-based RPC
https://grpc.io/docs/languages/kotlin
Apache License 2.0

Provide a Kotlin-friendly (i.e., coroutine-based) API for interceptors? #223

Open kvn-esque opened 3 years ago

kvn-esque commented 3 years ago

Looking at the interceptor API, it seems the only way to perform asynchronous, non-blocking I/O is something along the lines of https://stackoverflow.com/questions/53651024/grpc-java-async-call-in-serverinterceptor

Is it possible to provide a more Kotlin-friendly API for interceptors?

kvn-esque commented 3 years ago

Any thoughts? @jamesward @lowasser

jamesward commented 3 years ago

This would be one for @lowasser

lowasser commented 3 years ago

Does Java provide any async API for interceptors? I don't think it does.

kvn-esque commented 3 years ago

This looks like an async API for gRPC in Java land:

https://grpc.github.io/grpc-java/javadoc/io/grpc/ServerCall.Listener.html

lowasser commented 3 years ago

Yes, but that's for the RPC part, not interceptors.

Java interceptors appear to be strictly synchronous.
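
To make the distinction concrete: `interceptCall` runs synchronously and has to produce a `ServerCall.Listener` before it returns, whereas that listener's callbacks (`onMessage`, `onHalfClose`, ...) are invoked later by gRPC as the call progresses. A minimal sketch, assuming grpc-java's `io.grpc` API is on the classpath; `countingListener` and `MessageCountingInterceptor` are illustrative names, not library API.

```kotlin
import io.grpc.ForwardingServerCallListener
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import java.util.concurrent.atomic.AtomicInteger

// Wraps a listener so every inbound message bumps a counter before being forwarded.
fun <ReqT> countingListener(
    delegate: ServerCall.Listener<ReqT>,
    counter: AtomicInteger,
): ServerCall.Listener<ReqT> =
    object : ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(delegate) {
        // Called by gRPC later, once per request message, after interceptCall has returned.
        override fun onMessage(message: ReqT) {
            counter.incrementAndGet()
            super.onMessage(message)
        }
    }

// Illustrative interceptor: interceptCall is a plain (non-suspend) function, so there is
// nowhere to await I/O here; it can only set up the listener that handles later events.
class MessageCountingInterceptor(private val counter: AtomicInteger) : ServerInterceptor {
    override fun <ReqT, RespT> interceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>,
    ): ServerCall.Listener<ReqT> = countingListener(next.startCall(call, headers), counter)
}
```

Any check that needs a network round trip therefore has to either block inside `interceptCall` or defer its decision into those later callbacks.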

lns-ross commented 2 years ago

It's been a while and this is still an open question (and I hate to be 'that person'), but was there any consensus or conclusion, good or bad, on a Kotlin-friendly approach? We'd rather handle this by convention inside the calls than use the solution linked in the OP; that way we can at least make coroutine-based calls to sibling gRPC services. TIA.
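
A sketch of the "by convention" route, assuming grpc-java and grpc-kotlin: the interceptor stays synchronous and only copies a header into the `io.grpc.Context`, and each suspend RPC calls a helper first, which is free to make non-blocking calls. This relies on grpc-kotlin running the handler coroutine with the call's `Context` attached (via `GrpcContextElement`). `TOKEN_KEY`, `AUTH_HEADER`, `TokenToContextInterceptor`, and `requireToken` are hypothetical names.

```kotlin
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.Status
import io.grpc.StatusException

// Hypothetical key and header names used only for this sketch.
val TOKEN_KEY: Context.Key<String> = Context.key("auth-token")
val AUTH_HEADER: Metadata.Key<String> =
    Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER)

// Synchronous part: no I/O, just copy the header (possibly null) into the call's Context.
class TokenToContextInterceptor : ServerInterceptor {
    override fun <ReqT, RespT> interceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>,
    ): ServerCall.Listener<ReqT> {
        val ctx = Context.current().withValue(TOKEN_KEY, headers.get(AUTH_HEADER))
        return Contexts.interceptCall(ctx, call, headers, next)
    }
}

// Asynchronous part, called by convention at the top of each suspend RPC method.
// `verify` stands in for a suspending call to another service; a thrown
// StatusException becomes the RPC's status in grpc-kotlin.
suspend fun requireToken(verify: suspend (String) -> Boolean): String {
    val token = TOKEN_KEY.get() ?: throw StatusException(Status.UNAUTHENTICATED)
    if (!verify(token)) throw StatusException(Status.PERMISSION_DENIED)
    return token
}
```

In a service this would be the first line of each method, e.g. `requireToken { token -> authStub.check(token) }`, where `authStub.check` is a placeholder for a suspending call to a sibling service. The trade-off versus a real async interceptor is that every method has to remember to call it.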

be-hase commented 2 years ago

I think this proposal is a good idea, so I implemented the following interceptor and confirmed that the example code works.

```kotlin
// Imports reconstructed for readability (not part of the original comment). Assumed
// libraries: grpc-java, grpc-kotlin-stub, kotlinx-coroutines, kotlin-reflect,
// kotlin-logging (mu.KotlinLogging) and JSR-305 (javax.annotation.concurrent.GuardedBy).
import io.grpc.Context
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.Status
import io.grpc.kotlin.CoroutineContextServerInterceptor
import io.grpc.kotlin.GrpcContextElement
import javax.annotation.concurrent.GuardedBy
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.reflect.full.companionObject
import kotlin.reflect.full.companionObjectInstance
import kotlin.reflect.full.memberProperties
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import mu.KotlinLogging

/**
 * Base class for server interceptors that need to suspend (for example, to do
 * non-blocking I/O) before deciding how to handle a call.
 *
 * https://stackoverflow.com/questions/53651024/grpc-java-async-call-in-serverinterceptor
 */
abstract class SuspendableServerInterceptor(
    private val context: CoroutineContext = EmptyCoroutineContext
) : ServerInterceptor {
    override fun <ReqT : Any, RespT : Any> interceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>
    ): ServerCall.Listener<ReqT> {
        // Return a placeholder listener immediately; it buffers callbacks until the
        // coroutine below has produced the real listener.
        val delayedListener = DelayedListener<ReqT>()
        delayedListener.job = CoroutineScope(
            GrpcContextElement.current()
                + COROUTINE_CONTEXT_KEY.get()
                + context
        ).launch {
            try {
                delayedListener.realListener = suspendableInterceptCall(call, headers, next)
                delayedListener.drainPendingCallbacks()
            } catch (e: CancellationException) {
                log.debug { "Caught CancellationException. $e" }
                call.close(Status.CANCELLED, Metadata())
            } catch (e: Exception) {
                log.error(e) { "Unhandled exception. $e" }
                call.close(Status.UNKNOWN, Metadata())
            }
        }
        return delayedListener
    }

    abstract suspend fun <ReqT : Any, RespT : Any> suspendableInterceptCall(
        call: ServerCall<ReqT, RespT>,
        headers: Metadata,
        next: ServerCallHandler<ReqT, RespT>
    ): ServerCall.Listener<ReqT>

    /**
     * Queues listener callbacks until [realListener] is available, then forwards them directly.
     *
     * ref: https://github.com/grpc/grpc-java/blob/84edc332397ed01fae2400c25196fc90d8c1a6dd/core/src/main/java/io/grpc/internal/DelayedClientCall.java#L415
     */
    private class DelayedListener<ReqT> : ServerCall.Listener<ReqT>() {
        var realListener: ServerCall.Listener<ReqT>? = null

        @Volatile
        private var passThrough = false

        @GuardedBy("this")
        private var pendingCallbacks: MutableList<Runnable> = mutableListOf()

        // Written in interceptCall and read in onCancel, possibly on different threads.
        @Volatile
        var job: Job? = null

        override fun onMessage(message: ReqT) {
            if (passThrough) {
                checkNotNull(realListener).onMessage(message)
            } else {
                delayOrExecute { checkNotNull(realListener).onMessage(message) }
            }
        }

        override fun onHalfClose() {
            if (passThrough) {
                checkNotNull(realListener).onHalfClose()
            } else {
                delayOrExecute { checkNotNull(realListener).onHalfClose() }
            }
        }

        override fun onCancel() {
            // Stop the pending suspendableInterceptCall, if it is still running.
            job?.cancel()
            if (passThrough) {
                checkNotNull(realListener).onCancel()
            } else {
                delayOrExecute { checkNotNull(realListener).onCancel() }
            }
        }

        override fun onComplete() {
            if (passThrough) {
                checkNotNull(realListener).onComplete()
            } else {
                delayOrExecute { checkNotNull(realListener).onComplete() }
            }
        }

        override fun onReady() {
            if (passThrough) {
                checkNotNull(realListener).onReady()
            } else {
                delayOrExecute { checkNotNull(realListener).onReady() }
            }
        }

        // Queue the callback while the real listener is not ready; otherwise run it now.
        private fun delayOrExecute(runnable: Runnable) {
            synchronized(this) {
                if (!passThrough) {
                    pendingCallbacks.add(runnable)
                    return
                }
            }
            runnable.run()
        }

        fun drainPendingCallbacks() {
            check(!passThrough)
            var toRun: MutableList<Runnable> = mutableListOf()
            while (true) {
                synchronized(this) {
                    if (pendingCallbacks.isEmpty()) {
                        pendingCallbacks = mutableListOf()
                        passThrough = true
                        return
                    }
                    // Since there were pendingCallbacks, we need to process them. To maintain ordering we
                    // can't set passThrough=true until we run all pendingCallbacks, but new Runnables may be
                    // added after we drop the lock. So we will have to re-check pendingCallbacks.
                    val tmp: MutableList<Runnable> = toRun
                    toRun = pendingCallbacks
                    pendingCallbacks = tmp
                }
                for (runnable in toRun) {
                    // Avoid calling listener while lock is held to prevent deadlocks.
                    runnable.run()
                }
                toRun.clear()
            }
        }
    }

    companion object {
        private val log = KotlinLogging.logger {}

        // CoroutineContextServerInterceptor.COROUTINE_CONTEXT_KEY is not public in
        // grpc-kotlin, so it is read via reflection.
        @Suppress("UNCHECKED_CAST")
        internal val COROUTINE_CONTEXT_KEY: Context.Key<CoroutineContext> =
            CoroutineContextServerInterceptor::class.let { kclass ->
                val companionObject = kclass.companionObject!!
                val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" }
                checkNotNull(property.getter.call(kclass.companionObjectInstance!!)) as Context.Key<CoroutineContext>
            }
    }
}
```

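
Stripped of gRPC types, `DelayedListener` follows a queue-then-pass-through pattern. The sketch below models just that pattern in plain Kotlin so it can be read and tested in isolation; `BufferingForwarder`, `dispatch`, and `setTarget` are hypothetical names standing in for the listener, `delayOrExecute`, and `drainPendingCallbacks`.

```kotlin
// Hypothetical, gRPC-free model of DelayedListener: events that arrive before the
// target exists are queued, then replayed in order once setTarget() is called.
class BufferingForwarder<T : Any> {
    @Volatile
    private var passThrough = false

    // Written once in setTarget(), before passThrough is published.
    private var target: T? = null

    // Guarded by `this`.
    private var pending = ArrayList<(T) -> Unit>()

    // Counterpart of delayOrExecute(): queue while not ready, otherwise run directly.
    fun dispatch(event: (T) -> Unit) {
        if (!passThrough) {
            synchronized(this) {
                if (!passThrough) {
                    pending.add(event)
                    return
                }
            }
        }
        event(checkNotNull(target))
    }

    // Counterpart of drainPendingCallbacks(): replay queued events outside the lock,
    // re-check for events queued meanwhile, and only then switch to pass-through.
    fun setTarget(value: T) {
        check(!passThrough) { "target already set" }
        target = value
        var toRun = ArrayList<(T) -> Unit>()
        while (true) {
            synchronized(this) {
                if (pending.isEmpty()) {
                    passThrough = true
                    return
                }
                val swap = toRun
                toRun = pending
                pending = swap
            }
            for (event in toRun) event(value)
            toRun.clear()
        }
    }
}
```

Holding the lock only for the list swap, rather than while replaying, is the deadlock-avoidance point called out in the comments of `drainPendingCallbacks`; the re-check loop is what keeps events queued during a replay in order.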
jetaggart commented 2 years ago

@be-hase can you provide your imports? I'm having a hard time seeing where a few things come from.

gregkonush commented 1 year ago

I wonder if there is an easier approach to this; it would be nice if the library provided this natively.