awslabs / amazon-transcribe-streaming-sdk

The Amazon Transcribe Streaming SDK is an async Python SDK for converting audio into text via Amazon Transcribe.
Apache License 2.0
138 stars 38 forks source link

Support for AWS Transcribe Medical #24

Open tmarice opened 3 years ago

tmarice commented 3 years ago

Hello,

Since AWS released the Medical version of the Transcribe service, it would be great if this SDK supported it natively as well. The APIs are very similar, so we hacked together a rough TranscribeMedicalStreamingClient by subclassing TranscribeStreamingClient. We used similar subclassing hacks for TranscribeMedicalStreamingRequestSerializer and StartMedicalStreamTranscriptionRequest:

from io import BufferedIOBase
from typing import Dict, Optional, Tuple

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.handlers import TranscriptResultStreamHandler
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import HEADER_VALUE, Serializer, TranscribeStreamingRequestSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.structures import BufferableByteStream
from amazon_transcribe.utils import _add_required_headers

##

class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    """Request shape for a Transcribe Medical streaming session.

    Accepts everything ``StartStreamTranscriptionRequest`` accepts, plus two
    required keyword arguments, ``audio_type`` and ``specialty``. They are
    stored as attributes of the same name for the serializer to read. Omitting
    either one raises ``KeyError``.
    """

    def __init__(self, *args, **kwargs):
        # Remove the medical-only fields (audio_type first, then specialty)
        # so they are not forwarded to the base constructor.
        medical_fields = {
            field: kwargs.pop(field) for field in ("audio_type", "specialty")
        }

        super().__init__(*args, **kwargs)

        self.audio_type = medical_fields["audio_type"]
        self.specialty = medical_fields["specialty"]

##

class TranscribeMedicalStreamingRequestSerializer(TranscribeStreamingRequestSerializer):
    """Serializer that targets the Transcribe Medical streaming endpoint.

    Construction is delegated to the base serializer. Only the request URI
    and the header set change: the medical-only ``specialty`` and ``type``
    headers are added.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Point requests at the medical path instead of the standard one.
        self.request_uri = "/medical-stream-transcription"

    def serialize(self) -> Tuple[Dict[str, HEADER_VALUE], BufferedIOBase]:
        """Return the request headers together with a fresh, empty body stream.

        Returns:
            ``(headers, body)``. ``headers`` maps each
            ``x-amzn-transcribe-*`` header name to the matching attribute of
            ``self.request_shape``, plus whatever ``_add_required_headers``
            adds for ``self.endpoint``. ``body`` is a new
            ``BufferableByteStream``.
        """
        shape = self.request_shape
        # NOTE(review): the vocabulary-filter headers are forwarded unchanged.
        # Confirm the medical endpoint accepts them, or that None-valued
        # headers are dropped downstream.
        headers = {
            "x-amzn-transcribe-language-code": shape.language_code,
            "x-amzn-transcribe-sample-rate": shape.media_sample_rate_hz,
            "x-amzn-transcribe-media-encoding": shape.media_encoding,
            "x-amzn-transcribe-vocabulary-name": shape.vocabulary_name,
            "x-amzn-transcribe-session-id": shape.session_id,
            "x-amzn-transcribe-vocabulary-filter-method": shape.vocab_filter_method,
            "x-amzn-transcribe-vocabulary-filter-name": shape.vocab_filter_name,
            "x-amzn-transcribe-show-speaker-label": shape.show_speaker_label,
            "x-amzn-transcribe-enable-channel-identification": shape.enable_channel_identification,
            "x-amzn-transcribe-number-of-channels": shape.number_of_channels,
            # Medical-only headers.
            "x-amzn-transcribe-specialty": shape.specialty,
            "x-amzn-transcribe-type": shape.audio_type,
        }
        _add_required_headers(self.endpoint, headers)
        return headers, BufferableByteStream()

##

class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    """Streaming client for Amazon Transcribe Medical.

    Reuses the ``TranscribeStreamingClient`` machinery: endpoint and
    credential resolution, response parsing and audio-stream creation.
    Requests are built with ``TranscribeMedicalStreamingRequestSerializer``,
    so they carry the medical ``specialty``/``type`` fields.
    """

    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        """Open a medical streaming transcription session.

        Args:
            language_code, media_sample_rate_hz, media_encoding,
            vocabulary_name, session_id, vocab_filter_method,
            vocab_filter_name, show_speaker_label,
            enable_channel_identification, number_of_channels:
                Forwarded unchanged into the request shape.
            audio_type: Stored on the request shape as ``audio_type``.
            specialty: Stored on the request shape as ``specialty``.

        Returns:
            A ``StartStreamTranscriptionEventStream`` built from the audio
            input stream (created from the signed request) and the parsed
            response.

        Raises:
            The exception produced by ``self._response_parser.parse_exception``
            when the status code is >= 400.
            RuntimeError: For any other status code that is not 200.
        """
        # NOTE(review): the positional order must match the base
        # StartStreamTranscriptionRequest constructor. Re-check when
        # upgrading the SDK. The medical fields go by keyword so the
        # subclass can pop them.
        transcribe_streaming_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)
        # The medical serializer is created per call and kept on the instance.
        self._serializer: Serializer = TranscribeMedicalStreamingRequestSerializer(
            endpoint=endpoint,
            transcribe_request=transcribe_streaming_request,
        )
        request = self._serializer.serialize_to_request()

        # Signed with the same "transcribe" service name as standard streaming.
        creds = await self._credential_resolver.get_credentials()
        signer = SigV4RequestSigner("transcribe", self.region)
        signed_request = signer.sign(request, creds)

        session = AwsCrtHttpSessionManager(self._eventloop)

        response = await session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        elif status_code != 200:
            # Close the request body on this failure path too, so the input
            # stream is not left open when no session will use it.
            signed_request.body.close()
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)
mikeballou-augmedix commented 3 years ago

Thanks @tmarice! This is awesome. Your code needed some modifications to work with the latest SDK release (0.4.0).

from typing import Optional

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import TranscribeStreamingSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.request import Request

##

class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    """Standard streaming request extended with Transcribe Medical fields.

    ``audio_type`` and ``specialty`` must be passed as keyword arguments. A
    missing one raises ``KeyError``. All other arguments go unchanged to
    ``StartStreamTranscriptionRequest``.
    """

    def __init__(self, *args, **kwargs):
        # Pop audio_type before specialty, then hand the rest to the base class.
        audio_type, specialty = kwargs.pop("audio_type"), kwargs.pop("specialty")
        super().__init__(*args, **kwargs)
        self.audio_type, self.specialty = audio_type, specialty

##

class TranscribeMedicalStreamingSerializer(TranscribeStreamingSerializer):
    """Serializer for Transcribe Medical requests (SDK 0.4.0 client API).

    The standard headers come from ``TranscribeStreamingSerializer``. This
    class switches the path to the medical endpoint and adds the medical-only
    ``specialty`` and ``type`` headers.
    """

    def __init__(self):
        super().__init__()

        # The medical service is reached at its own path.
        self.request_uri = "/medical-stream-transcription"

    def serialize_start_stream_transcription_request(
        self, endpoint: str, request_shape: StartMedicalStreamTranscriptionRequest
    ) -> Request:
        """Build the ``Request`` for a medical streaming session.

        Args:
            endpoint: Resolved service endpoint, passed to the base serializer.
            request_shape: A request that carries ``specialty`` and
                ``audio_type`` in addition to the standard fields.

        Returns:
            The base serializer's ``Request``, with ``path`` set to
            ``self.request_uri`` and the ``specialty``/``type`` headers
            merged into its headers.
        """
        request = super().serialize_start_stream_transcription_request(endpoint, request_shape)
        # Set the path explicitly instead of relying on the base implementation.
        request.path = self.request_uri

        # Call the helper through ``self`` rather than ``super()`` so that an
        # override in a further subclass is honoured.
        request.headers.update(self._serialize_str_header("specialty", request_shape.specialty))
        request.headers.update(self._serialize_str_header("type", request_shape.audio_type))

        return request

##

class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    """Streaming client for Amazon Transcribe Medical (SDK 0.4.0 variant).

    Replaces the client's serializer with ``TranscribeMedicalStreamingSerializer``
    at construction time. Everything else (endpoint and credential
    resolution, response parsing, audio-stream creation) comes from
    ``TranscribeStreamingClient``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace whatever serializer the base client installed.
        self._serializer = TranscribeMedicalStreamingSerializer()

    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        """Open a medical streaming transcription session.

        Args:
            language_code, media_sample_rate_hz, media_encoding,
            vocabulary_name, session_id, vocab_filter_method,
            vocab_filter_name, show_speaker_label,
            enable_channel_identification, number_of_channels:
                Forwarded unchanged into the request shape.
            audio_type: Stored on the request shape as ``audio_type``.
            specialty: Stored on the request shape as ``specialty``.

        Returns:
            A ``StartStreamTranscriptionEventStream`` built from the audio
            input stream (created from the signed request) and the parsed
            response.

        Raises:
            The exception produced by ``self._response_parser.parse_exception``
            when the status code is >= 400.
            RuntimeError: For any other status code that is not 200.
        """
        # NOTE(review): the positional order must match the base
        # StartStreamTranscriptionRequest constructor. Re-check when
        # upgrading the SDK. The medical fields go by keyword so the
        # subclass can pop them.
        transcribe_streaming_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)

        # Serialize with the medical serializer installed in __init__.
        # ``.prepare()`` presumably finalizes the request for signing.
        request = self._serializer.serialize_start_stream_transcription_request(
            endpoint=endpoint, request_shape=transcribe_streaming_request,
        ).prepare()

        # Signed with the same "transcribe" service name as standard streaming.
        creds = await self._credential_resolver.get_credentials()
        signer = SigV4RequestSigner("transcribe", self.region)
        signed_request = signer.sign(request, creds)

        session = AwsCrtHttpSessionManager(self._eventloop)

        response = await session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        elif status_code != 200:
            # Close the request body on this failure path too, so the input
            # stream is not left open when no session will use it.
            signed_request.body.close()
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)
david-oliveira-br commented 2 years ago

Hey guys, any updates about that?

vikramsubramanian commented 1 year ago

Any updates? I'm surprised this hasn't been built already.

alexe0336 commented 1 month ago

Any updates? Also I'd like support for show_speaker_labels.