AssemblyAI / assemblyai-python-sdk

AssemblyAI's Official Python SDK
https://assemblyai.com
MIT License

Transcribe Group will yield results in a different order than inputs #79


Ruben-Kruepper commented 1 month ago

Calling transcriptions = transcriber.transcribe_group([str(c.audio_file) for c in audio_chunks]) may yield the results out of order, i.e., transcriptions[i] does not necessarily correspond to audio_chunks[i]. Is this intended? If so, how are we supposed to know which result corresponds to which input?

ploeber commented 1 month ago

Hi, thanks for flagging! Yes, it's using concurrent.futures.wait() under the hood, so it does not preserve the order; results are returned in the order they become available.

Unfortunately, I don't see a way to connect inputs with outputs the way it is implemented right now, so you'd have to write your own logic to submit multiple files concurrently and preserve the order.

We can try to work on an enhancement that makes sure the order is preserved. If not, at the very least this should be mentioned in the function docstring.
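
For reference, here is a minimal, self-contained sketch (using plain concurrent.futures rather than the SDK) of why the order is lost: wait() returns a set of completed futures, which has no defined iteration order, while iterating the list of futures as they were submitted preserves it.

import time
from concurrent.futures import ThreadPoolExecutor, wait

def fake_transcribe(i):
    time.sleep(0.1 * (5 - i))  # later submissions finish first
    return i

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(fake_transcribe, i) for i in range(5)]
    done, _ = wait(futures)
    print([f.result() for f in done])     # set iteration order: arbitrary
    print([f.result() for f in futures])  # submission order: [0, 1, 2, 3, 4]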

serozhenka commented 1 month ago

@Ruben-Kruepper As a temporary workaround, I have used the following code:

from concurrent.futures import wait

# `transcriber`, `config`, and `video_url` are assumed to be defined
# elsewhere; the same URL is simply submitted five times as a demo.
video_urls: list[str] = [video_url for _ in range(5)]
futures = [
    transcriber.transcribe_async(video_url, config=config)
    for video_url in video_urls
]
# Tag each future with the input it belongs to, so the pairing
# survives even though wait() returns an unordered set.
for video_url, future in zip(video_urls, futures):
    setattr(future, "__video_url__", video_url)

futures = wait(futures).done
for future in futures:
    print(getattr(future, "__video_url__"), future.result().text)

Agree that it's not as clear as one would want, but it is what it is :)
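
A slightly more idiomatic variant of the same idea (a sketch assuming the same transcribe_async()/config setup as above) keeps a dict from future to input and iterates with as_completed():

from concurrent.futures import as_completed

futures = {
    transcriber.transcribe_async(url, config=config): url
    for url in video_urls
}
for future in as_completed(futures):
    print(futures[future], future.result().text)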

Ruben-Kruepper commented 1 month ago

@serozhenka Good one, very compact! We're using this, which maps quite closely onto what the SDK is doing under the hood right now.

import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List, Optional

class Dispatcher:
    """Run a callable over many inputs concurrently, rate-limited,
    returning results in input order."""

    def __init__(
        self,
        action: Callable,
        calls_per_minute: int,
        concurrent_calls: int,
        progress_cb: Optional[Callable] = None,
        description: Optional[str] = None,
    ):
        self.action = action
        self.calls_per_minute = calls_per_minute
        self.concurrent_calls = concurrent_calls
        self.interval = 60 / calls_per_minute  # minimum spacing between call starts
        self.last_call_time = 0.0
        self.lock = threading.Lock()  # guards the rate-limit bookkeeping
        self.progress_lock = threading.Lock()  # guards the completion counter
        self.progress_cb = progress_cb
        self.description = description
        self.tracking_id = 0  # number of completed calls, for progress reporting

    def _rate_limited_action(self, arg: Any) -> Any:
        # Sleep while holding the lock so that concurrent workers are
        # spaced at least `self.interval` seconds apart at call start.
        with self.lock:
            current_time = time.time()
            time_since_last_call = current_time - self.last_call_time
            if time_since_last_call < self.interval:
                time.sleep(self.interval - time_since_last_call)
            self.last_call_time = time.time()

        return self.action(arg)

    def bulk_process(self, to_process: Iterable[Any]) -> List[Any]:
        to_process = list(to_process)  # materialize so we can index
        results = [None] * len(to_process)  # pre-allocated, filled by index
        n_total = len(to_process)

        def process_with_index(index_and_arg):
            index, arg = index_and_arg
            result = self._rate_limited_action(arg)
            results[index] = result  # write by input index, preserving order
            with self.progress_lock:  # `+=` on a shared counter is not atomic
                self.tracking_id += 1
                completed = self.tracking_id
            if self.progress_cb:
                desc = self.description or "Processing"
                self.progress_cb(desc, completed / n_total)

        if self.progress_cb:
            desc = self.description or "Processing"
            self.progress_cb(desc, 0)

        with ThreadPoolExecutor(max_workers=self.concurrent_calls) as executor:
            # executor.map preserves input order; wrapping it in list() just
            # drives the iterator and re-raises any worker exceptions.
            list(executor.map(process_with_index, enumerate(to_process)))

        return results

This is also very useful for general parallel API calls 😁
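
A hypothetical usage sketch against the original question (the `transcriber` and `audio_chunks` names are assumptions from that example, not part of Dispatcher itself):

dispatcher = Dispatcher(
    action=transcriber.transcribe,
    calls_per_minute=60,
    concurrent_calls=8,
)
transcriptions = dispatcher.bulk_process(str(c.audio_file) for c in audio_chunks)
# transcriptions[i] now corresponds to audio_chunks[i]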