OpenNMT / CTranslate2

Fast inference engine for Transformer models
https://opennmt.net/CTranslate2
MIT License

Unable to use translate_batch of the NLLB model via multiprocessing #1414

Closed abhishekukmeesho closed 1 year ago

abhishekukmeesho commented 1 year ago

I'm trying to translate a list of sentences from English into one of the target languages. I was planning to do this via multiprocessing, passing the sentences as a list. But whatever I do, the translate_batch function gets stuck when called through Python's multiprocessing module and never proceeds further.

E.g., the code below works fine (taken from one of the examples you provided):

import datetime

import ctranslate2
import sentencepiece as spm

ct_model_path = "nllb-200-distilled-600M-int8"
sp_model_path = "flores200_sacrebleu_tokenizer_spm.model"
device = "cpu"

sp = spm.SentencePieceProcessor()
sp.load(sp_model_path)

src_lang = "eng_Latn"
tgt_lang = "kan_Knda"
beam_size = 1

source_sents = title_list[:50000] # Load 50,000 sentences in a list
source_sents = [sent.strip() for sent in source_sents]
target_prefix = [[tgt_lang]] * len(source_sents)

# Subword the source sentences
source_sents_subworded = sp.encode(source_sents, out_type=str)
source_sents_subworded = [[src_lang] + sent + ["</s>"] for sent in source_sents_subworded]

# Translate the source sentences
translator = ctranslate2.Translator(ct_model_path, device=device, inter_threads=16)
print(datetime.datetime.now())
translations = translator.translate_batch(source_sents_subworded, batch_type="tokens", max_batch_size=2026, beam_size=beam_size, target_prefix=target_prefix)
print(datetime.datetime.now())
translations = [translation.hypotheses[0] for translation in translations]

# Desubword the target sentences
translations_desubword = sp.decode(translations)
print(datetime.datetime.now())
translations_desubword = [sent[len(tgt_lang):] for sent in translations_desubword]
print(translations_desubword)

But when I try the same via multiprocessing, the code gets stuck:

from multiprocessing import Pool

translator = ctranslate2.Translator(ct_model_path, device=device, inter_threads=16)

def get_batch_translation(title_batch):
    source_sents = [sent.strip() for sent in title_batch]
    target_prefix = [[tgt_lang]] * len(source_sents)

    # Subword the source sentences
    source_sents_subworded = sp.encode(source_sents, out_type=str)
    source_sents_subworded = [[src_lang] + sent + ["</s>"] for sent in source_sents_subworded]
    # Translate the source sentences
    # translator = ctranslate2.Translator(ct_model_path, device=device)
    translations = translator.translate_batch(source_sents_subworded, batch_type="tokens", max_batch_size=2024, beam_size=beam_size, target_prefix=target_prefix)
    translations = [translation.hypotheses[0] for translation in translations]
    # Desubword the target sentences
    translations_desubword = sp.decode(translations)
    translations_desubword = [sent[len(tgt_lang):] for sent in translations_desubword]

    return translations_desubword

def get_batch_translations_parallely(titles_batch):
    with Pool(5) as pool:
        results = pool.map(get_batch_translation, titles_batch)

    return results

results = get_batch_translations_parallely(title_batch_list[:5])  # each element of title_batch_list is a list of 10,000 titles

Is there any way I can include Python's multiprocessing in this? The aim is to further reduce the translation time. Currently it takes ~10 minutes to translate 50k sentences with my settings. Is there any way to reduce this to <2 mins? I tried translate_file as well, but it did not reduce the time taken. Any help or suggestion regarding this is appreciated. TIA.

guillaumekln commented 1 year ago

The Translator instance cannot be copied to another process, so you should create one translator instance per process.
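One way to follow this advice is a `multiprocessing.Pool` initializer, so each worker process builds its own translator exactly once. The sketch below shows only the pattern: a string stands in for the heavy `ctranslate2.Translator` and SentencePiece objects, which cannot be pickled across processes.

```python
from multiprocessing import get_context

# Per-process state, filled in by the initializer. In the real code this
# would hold a ctranslate2.Translator and a SentencePieceProcessor.
_state = {}

def init_worker(model_path):
    # Runs once in every worker process, so the heavy objects are built
    # there instead of being copied from the parent.
    _state["translator"] = f"translator({model_path})"  # stand-in object

def translate_chunk(sentences):
    translator = _state["translator"]
    # Stand-in for translator.translate_batch(...)
    return [f"{translator}: {s}" for s in sentences]

if __name__ == "__main__":
    ctx = get_context("spawn")  # fresh interpreters per worker
    with ctx.Pool(2, initializer=init_worker, initargs=("nllb-200",)) as pool:
        chunks = [["hello", "world"], ["good morning"]]
        print(pool.map(translate_chunk, chunks))
```

`init_worker` is where the `ctranslate2.Translator(...)` and `sp.load(...)` calls from the thread would go; `translate_chunk` then uses the process-local instance instead of one inherited from the parent.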

However, using translate_file or translate_iterable would be more efficient than using multiprocessing, especially when running the translations on CPU.

Can you post the code showing how you tried to use translate_file? How many CPU cores do you have?

abhishekukmeesho commented 1 year ago

Hi @guillaumekln, even if I move the instantiation line inside the multiprocessing function, the code still hangs. E.g.:

from multiprocessing import Pool

def get_batch_translation(title_batch):
    source_sents = [sent.strip() for sent in title_batch]
    target_prefix = [[tgt_lang]] * len(source_sents)

    # Subword the source sentences
    source_sents_subworded = sp.encode(source_sents, out_type=str)
    source_sents_subworded = [[src_lang] + sent + ["</s>"] for sent in source_sents_subworded]
    # Translate the source sentences
    translator = ctranslate2.Translator(ct_model_path, device=device, inter_threads=16)
    translations = translator.translate_batch(source_sents_subworded, batch_type="tokens", max_batch_size=2024, beam_size=beam_size, target_prefix=target_prefix)
    translations = [translation.hypotheses[0] for translation in translations]
    # Desubword the target sentences
    translations_desubword = sp.decode(translations)
    translations_desubword = [sent[len(tgt_lang):] for sent in translations_desubword]

    return translations_desubword

def get_batch_translations_parallely(titles_batch):
    with Pool(5) as pool:
        results = pool.map(get_batch_translation, titles_batch)

    return results

results = get_batch_translations_parallely(title_batch_list[:5])  # each element of title_batch_list is a list of 10,000 titles

abhishekukmeesho commented 1 year ago

Also, sorry, I got confused between translate_file and translate_iterable. I used translate_iterable by simply replacing translate_batch with it, but the time taken was still the same. Maybe I am not using it the right way. Can you please point me to any examples you might have for translate_file and translate_iterable? I'll check them out. I am using an r5.4x EC2 machine, which has 16 cores and 128 GB of memory.

guillaumekln commented 1 year ago

even if I copy the instantiation line inside the multiprocess function, the code still hangs.

It's possible the import should also be run in each process. In general the library is not fork-safe.
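One way to sidestep fork-safety issues is the `spawn` start method, which starts each child as a fresh interpreter so the library is imported and initialized there rather than inherited from a forked parent. A minimal sketch of the idea, with a stdlib module standing in for the real `ctranslate2` import:

```python
import multiprocessing as mp

def worker(texts):
    # Import the library inside the child process. With "spawn", each
    # child is a clean interpreter, so no forked thread or lock state
    # is inherited from the parent.
    # Real code would do: import ctranslate2; build the Translator here.
    import json  # stdlib stand-in for the per-process import
    return [json.dumps(t) for t in texts]

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(2) as pool:
        print(pool.map(worker, [["a"], ["b", "c"]]))
```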

I used the translate_iterable function where I just replaced the translate_batch with translate_iterable but the time taken was still the same.

This will not change the performance indeed. The point of translate_iterable is that you don't need to build the batches yourself.

If you are translating a file the code would look like this:

import ctranslate2

# Use 16 worker threads, each using 1 computation thread.
translator = ctranslate2.Translator("model/", intra_threads=1, inter_threads=16)

with open("test.txt") as input_file:
    tokenize_fn = lambda line: line.strip().split()  # replace this by your tokenization.
    tokens = map(tokenize_fn, input_file)
    results = translator.translate_iterable(tokens)

    for result in results:
        print(result)

If the input stream is large enough, you should see all 16 cores fully used.
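Adapted to the NLLB setup in this thread, the same pattern additionally needs the source language token, the `</s>` terminator, and a constant target prefix per sentence. A sketch: the helper functions below are hypothetical, and the commented-out call assumes `translate_iterable` pairs the prefix iterable with the sources element by element.

```python
def wrap_nllb_source(tokens, src_lang="eng_Latn"):
    # NLLB source sequences start with the language token and end with </s>.
    return [src_lang] + tokens + ["</s>"]

def nllb_sources(lines, encode, src_lang="eng_Latn"):
    # `encode` would be: lambda s: sp.encode(s, out_type=str)
    for line in lines:
        yield wrap_nllb_source(encode(line.strip()), src_lang)

# With translator and sp loaded as earlier in the thread:
# import itertools
# with open("test.txt") as input_file:
#     results = translator.translate_iterable(
#         nllb_sources(input_file, lambda s: sp.encode(s, out_type=str)),
#         target_prefix=itertools.repeat(["kan_Knda"]),
#     )
#     for result in results:
#         print(sp.decode(result.hypotheses[0]))
```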