grpc / grpc-kotlin

Kotlin gRPC implementation. HTTP/2 based RPC
https://grpc.io/docs/languages/kotlin
Apache License 2.0
1.2k stars 165 forks source link

feat: capture call site coroutine context into call options #592

Open monosoul opened 8 months ago

monosoul commented 8 months ago

This PR introduces a change to how client coroutine stubs are generated.

It makes the generated code capture the call-site coroutine context into a call option. This is a first step toward adding a suspendable client interceptor.

The change is backwards compatible and doesn't change any API of the generated stubs. If the call options have no coroutine context, an empty one will be returned.

linux-foundation-easycla[bot] commented 8 months ago

CLA Signed


The committers listed above are authorized under a signed CLA.

monosoul commented 8 months ago

Here's a diff for the stub used in tests:

Index: stub/build/generated/source/proto/test/grpckt/io/grpc/examples/helloworld/HelloWorldProtoGrpcKt.kt
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/stub/build/generated/source/proto/test/grpckt/io/grpc/examples/helloworld/HelloWorldProtoGrpcKt.kt b/stub/build/generated/source/proto/test/grpckt/io/grpc/examples/helloworld/HelloWorldProtoGrpcKt.kt
--- a/stub/build/generated/source/proto/test/grpckt/io/grpc/examples/helloworld/HelloWorldProtoGrpcKt.kt    
+++ b/stub/build/generated/source/proto/test/grpckt/io/grpc/examples/helloworld/HelloWorldProtoGrpcKt.kt    (date 1712076323265)
@@ -22,12 +22,16 @@
 import io.grpc.kotlin.ServerCalls.serverStreamingServerMethodDefinition
 import io.grpc.kotlin.ServerCalls.unaryServerMethodDefinition
 import io.grpc.kotlin.StubFor
+import io.grpc.kotlin.withCoroutineContext
 import kotlin.String
 import kotlin.coroutines.CoroutineContext
 import kotlin.coroutines.EmptyCoroutineContext
+import kotlin.coroutines.coroutineContext
 import kotlin.jvm.JvmOverloads
 import kotlin.jvm.JvmStatic
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.emitAll
+import kotlinx.coroutines.flow.flow

 /**
  * Holder for Kotlin coroutine-based client and server APIs for helloworld.Greeter.
@@ -84,7 +88,7 @@
       channel,
       GreeterGrpc.getSayHelloMethod(),
       request,
-      callOptions,
+      callOptions.withCoroutineContext(coroutineContext),
       headers
     )

@@ -112,7 +116,7 @@
       channel,
       GreeterGrpc.getClientStreamSayHelloMethod(),
       requests,
-      callOptions,
+      callOptions.withCoroutineContext(coroutineContext),
       headers
     )

@@ -130,13 +134,18 @@
      * @return A flow that, when collected, emits the responses from the server.
      */
     public fun serverStreamSayHello(request: MultiHelloRequest, headers: Metadata = Metadata()):
-        Flow<HelloReply> = serverStreamingRpc(
-      channel,
-      GreeterGrpc.getServerStreamSayHelloMethod(),
-      request,
-      callOptions,
-      headers
-    )
+        Flow<HelloReply> = 
+    flow {
+      emitAll(
+        serverStreamingRpc(
+          channel,
+          GreeterGrpc.getServerStreamSayHelloMethod(),
+          request,
+          callOptions.withCoroutineContext(coroutineContext),
+          headers
+        )
+      )
+    }

     /**
      * Returns a [Flow] that, when collected, executes this RPC and emits responses from the
@@ -159,13 +168,18 @@
      * @return A flow that, when collected, emits the responses from the server.
      */
     public fun bidiStreamSayHello(requests: Flow<HelloRequest>, headers: Metadata = Metadata()):
-        Flow<HelloReply> = bidiStreamingRpc(
-      channel,
-      GreeterGrpc.getBidiStreamSayHelloMethod(),
-      requests,
-      callOptions,
-      headers
-    )
+        Flow<HelloReply> = 
+    flow {
+      emitAll(
+        bidiStreamingRpc(
+          channel,
+          GreeterGrpc.getBidiStreamSayHelloMethod(),
+          requests,
+          callOptions.withCoroutineContext(coroutineContext),
+          headers
+        )
+      )
+    }
   }

   /**