caikit / caikit-nlp

Apache License 2.0

Do not surface `grpc.RpcError` from downstream calls to TGIS #354

Closed by gabe-l-hart 3 months ago

gabe-l-hart commented 3 months ago

Describe the bug

The modules that call downstream to TGIS need error semantics that are consistent with the other, local modules. Currently, the tgis_utils utilities do not catch errors raised by the downstream calls, and neither do the modules that call them, so these errors percolate up to the global try/except in the caikit.runtime server(s) when these modules are served.

When using the grpc_server, the global try/except already handles RpcError since it is a grpc server, but the http_server does not, which results in a 500-class error to the user even when the underlying TGIS exception is 400-class (e.g. INVALID_ARGUMENT). Since the http_server is generic, it should not need to know how the modules in caikit_nlp raise exceptions; it should be able to expect that every exception is raised as either a CaikitCoreException or a CaikitRuntimeException.
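
One possible shape for the fix is to translate the downstream error at the point of the call, so that callers only ever see a core-style exception that carries the original status class. The following is a minimal sketch, not the actual caikit_nlp change: RpcStatus, FakeRpcError, CoreStatus, CoreException, and translate_rpc_errors are all hypothetical stand-ins (for grpc.StatusCode, grpc.RpcError, and CaikitCoreException respectively) so that the snippet runs without grpc or caikit installed.

```python
import enum
from contextlib import contextmanager


class RpcStatus(enum.Enum):
    """Hypothetical stand-in for the subset of grpc.StatusCode used here."""
    INVALID_ARGUMENT = "invalid argument"
    UNAVAILABLE = "unavailable"
    UNKNOWN = "unknown"


class FakeRpcError(Exception):
    """Hypothetical stand-in for the grpc.RpcError raised by a client stub."""

    def __init__(self, code, details):
        super().__init__(details)
        self._code = code
        self._details = details

    def code(self):
        return self._code

    def details(self):
        return self._details


class CoreStatus(enum.Enum):
    """Hypothetical stand-in for the status codes of a core exception."""
    INVALID_ARGUMENT = "invalid argument"
    CONNECTION_ERROR = "connection error"
    UNKNOWN = "unknown"


class CoreException(Exception):
    """Hypothetical stand-in for CaikitCoreException."""

    def __init__(self, status_code, message):
        super().__init__(message)
        self.status_code = status_code
        self.message = message


# Assumed mapping; anything not listed falls back to UNKNOWN.
_STATUS_MAP = {
    RpcStatus.INVALID_ARGUMENT: CoreStatus.INVALID_ARGUMENT,
    RpcStatus.UNAVAILABLE: CoreStatus.CONNECTION_ERROR,
}


@contextmanager
def translate_rpc_errors():
    """Re-raise downstream RPC errors as core exceptions, keeping the cause."""
    try:
        yield
    except FakeRpcError as err:
        status = _STATUS_MAP.get(err.code(), CoreStatus.UNKNOWN)
        raise CoreException(status, err.details()) from err


def bad_generate():
    """Simulates a downstream call that fails with a 400-class error."""
    raise FakeRpcError(RpcStatus.INVALID_ARGUMENT, "Yikes!")


caught = None
try:
    with translate_rpc_errors():
        bad_generate()
except CoreException as err:
    caught = err
    print(f"{caught.status_code.name}: {caught.message}")
```

With a translation layer like this inside the utilities, a generic server only has to map the core status to its own response code, and an INVALID_ARGUMENT from TGIS no longer surfaces as a 500-class error.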

Sample Code

import grpc
import caikit_tgis_backend
import caikit
from caikit_nlp.toolkit.text_generation.tgis_utils import TGISGenerationClient
from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
from concurrent.futures import ThreadPoolExecutor

class BadTgis(caikit_tgis_backend.protobufs.generation_pb2_grpc.GenerationServiceServicer):
    def Generate(self, request, context):
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("Yikes!")
        return
    def GenerateStream(self, request, context):
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("Streaming Yikes!")

# Build the server
svc = BadTgis()
pool = ThreadPoolExecutor(max_workers=1)
server = grpc.server(thread_pool=pool)
caikit_tgis_backend.protobufs.generation_pb2_grpc.add_GenerationServiceServicer_to_server(svc, server)
server.add_insecure_port("[::]:12345")
server.start()

# Create the utils client
client = caikit_tgis_backend.protobufs.generation_pb2_grpc.GenerationServiceStub(grpc.insecure_channel("localhost:12345"))
gen_client = TGISGenerationClient("foo", "bar", client, caikit.core.data_model.ProducerId("foobar"))

# This should raise a CaikitCoreException
try:
    gen_client.unary_generate(
        text="foobar",
        preserve_input_text=True,
        input_tokens=True,
        generated_tokens=True,
        token_logprobs=True,
        token_ranks=True,
        max_new_tokens=20,
        min_new_tokens=20,
        truncate_input_tokens=True,
        decoding_method="GREEDY",
        top_k=None,
        top_p=None,
        typical_p=None,
        temperature=None,
        seed=None,
        repetition_penalty=0.5,
        max_time=None,
        exponential_decay_length_penalty=None,
        stop_sequences=["asdf"],
    )
except (CaikitCoreException, CaikitRuntimeException) as err:
    print("Caught the right error")

Expected behavior

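The call above should raise a CaikitCoreException (or CaikitRuntimeException), so the except clause catches it and prints "Caught the right error".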

Observed behavior

A grpc.RpcError is raised instead, which the except clause does not catch.