huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Pipeline to support batch inference #20973

Closed maiiabocharova closed 1 year ago

maiiabocharova commented 1 year ago

Feature request

Thank you for the awesome framework! For my work I wanted to use transformers.pipelines.token_classification.TokenClassificationPipeline in batch mode, since batching is much faster on a GPU, while keeping all the entity-grouping functionality. So I would like to suggest something like this:

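import torch
from transformers import pipeline

# `model` and `tokenizer` are assumed to be loaded beforehand
nlp = pipeline("ner", model=model,
               tokenizer=tokenizer,
               device=0 if torch.cuda.is_available() else -1,
               aggregation_strategy="average", batch_size=16)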

Motivation

I implemented it for myself and think it would be great to have this functionality out of the box so the community can enjoy the speedup (and it really does give a huge speedup).

Your contribution

I am willing to contribute and implement this change for the TokenClassification task (TextClassification and FeatureExtraction should be pretty much the same). I have not worked with the other pipelines, so I am not sure how batching is implemented there, but I am willing to try and contribute.

sgugger commented 1 year ago

cc @Narsil

Narsil commented 1 year ago

Hi @maiiabocharova, doesn't this already work out of the box?

import torch
from transformers import pipeline

pipe = pipeline(
    "ner",
    device=0 if torch.cuda.is_available() else -1,
    aggregation_strategy="average",
    batch_size=16,
)

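# Wrap model.forward to count how many times it is actually called
original_fn = pipe.model.forward
COUNT = 0

def new_forward(*args, **kwargs):
    global COUNT
    COUNT += 1
    return original_fn(*args, **kwargs)

pipe.model.forward = new_forward

# 20 identical inputs, streamed through a generator
def data():
    for i in range(20):
        yield "I live in New york"

for out in pipe(data()):
    print(out)

# With batch_size=16 this should be one call per batch, not one per input
print(f"Forward called {COUNT} times")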

This works, no?
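To see what number the counting trick should print without downloading a model, here is a toy reproduction of the same monkeypatching idea. `FakeModel` and `run_batched` are hypothetical stand-ins, not transformers code, and the real pipeline may group inputs differently; the point is only that a batching loop over 20 inputs with `batch_size=16` calls `forward` ceil(20 / 16) = 2 times instead of 20.

```python
import math


class FakeModel:
    """Hypothetical stand-in for a model: one 'logit' per input text."""

    def forward(self, batch):
        return [len(text) for text in batch]

    def __call__(self, batch):
        # Like torch.nn.Module, __call__ dispatches through self.forward,
        # so an instance-level override of forward is picked up.
        return self.forward(batch)


def run_batched(model, inputs, batch_size):
    """Hypothetical batching loop: one model call per full or leftover batch."""
    outputs, batch = [], []
    for item in inputs:
        batch.append(item)
        if len(batch) == batch_size:
            outputs.extend(model(batch))
            batch = []
    if batch:  # leftover partial batch
        outputs.extend(model(batch))
    return outputs


model = FakeModel()
original_fn = model.forward
COUNT = 0


def new_forward(*args, **kwargs):
    global COUNT
    COUNT += 1
    return original_fn(*args, **kwargs)


model.forward = new_forward


def data():
    for _ in range(20):
        yield "I live in New york"


outputs = run_batched(model, data(), batch_size=16)
print(f"Forward called {COUNT} times")
```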

maiiabocharova commented 1 year ago

Sorry, I was probably looking at the wrong source code:

for i, sentence in enumerate(_inputs):

    # Manage correct placement of the tensors
    with self.device_placement():

        tokens = self.tokenizer(
            sentence,
            return_attention_mask=False,
            return_tensors=self.framework,
            truncation=True,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
        )
        if self.tokenizer.is_fast:
            offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
        elif offset_mappings:
            offset_mapping = offset_mappings[i]
        else:
            offset_mapping = None

        special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]

But when I modified this part as follows,

for start_index in range(0, len(sentences), batch_size):
    sentences_batch = sentences[start_index:start_index+batch_size]
    with self.device_placement():

        tokens = self.tokenizer(
            sentences_batch,
            # Needed once padding is on, so the model ignores pad tokens
            return_attention_mask=True,
            return_tensors=self.framework,
            truncation=True,
            padding='longest',
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
        )
        if self.tokenizer.is_fast:
            offset_mapping_batch = tokens.pop("offset_mapping").cpu().numpy()
        special_tokens_mask_batch = tokens.pop("special_tokens_mask").cpu().numpy()
        with torch.no_grad():
            tokens = self.ensure_tensor_on_device(**tokens)
            entities_batch = self.model(**tokens)[0].cpu().numpy()
            input_ids_batch = tokens["input_ids"].cpu().numpy()
        scores_batch = np.exp(entities_batch) / np.exp(entities_batch).sum(-1, keepdims=True)

the pipeline ran 3x faster.
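Two details matter when post-processing a padded batch like this. Exponentiating raw logits directly can overflow for large values (a common remedy is to subtract the per-row maximum first), and `padding='longest'` adds positions that have to be dropped per sentence before aggregation. The minimal NumPy sketch below uses hypothetical helper names (`stable_softmax`, `unpad_scores`) and made-up logits, and assumes that padded positions are flagged 1 in `special_tokens_mask`, as Hugging Face tokenizers typically do when padding with `return_special_tokens_mask=True` (worth verifying for your tokenizer).

```python
import numpy as np


def stable_softmax(logits):
    # Subtracting the per-row max leaves the softmax unchanged mathematically
    # but keeps np.exp from overflowing on large logits.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def unpad_scores(scores_batch, special_tokens_mask_batch):
    # Assumption: special and padded positions are marked 1 in the mask,
    # so keeping positions where mask == 0 recovers each sentence's real tokens.
    return [
        scores[mask == 0]
        for scores, mask in zip(scores_batch, special_tokens_mask_batch)
    ]


# Made-up logits for 2 sentences padded to 4 positions, 3 labels each.
entities_batch = np.array([
    [[1000.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [5.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
])
special_tokens_mask_batch = np.array([
    [1, 0, 0, 1],  # [CLS] tok tok [SEP]
    [1, 0, 1, 1],  # [CLS] tok [SEP] [PAD]
])

scores_batch = stable_softmax(entities_batch)
per_sentence = unpad_scores(scores_batch, special_tokens_mask_batch)
for scores in per_sentence:
    print(scores.argmax(-1))
```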

P.S. Yes, you are right! Sorry, maybe I was also using an old version of the library. Sorry once again!

Narsil commented 1 year ago

Maybe it was an older version, indeed.

Also, the batching mechanism is not really explicit in the pipeline code; it is meant to be relatively orthogonal to it (making it explicit had too many drawbacks, such as code duplication, and made it really hard to support more complex use cases).
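The idea of keeping batching orthogonal to the task code can be illustrated with a toy sketch. Everything here (`preprocess`, `forward`, `postprocess`, `pad_collate`, `batched_pipeline`, the word-length "tokenizer" and `PAD_ID`) is hypothetical and is not the transformers implementation: the per-item steps only ever see one input, while a generic layer groups preprocessed items, pads them, calls `forward` once per batch, and splits the rows back out. This is how task-specific code can read as if it handles one input at a time while `batch_size` still takes effect.

```python
from typing import Dict, Iterable, Iterator, List

PAD_ID = 0  # hypothetical padding id


def preprocess(text: str) -> Dict[str, List[int]]:
    # Per-item step: a toy "tokenizer" mapping each word to its length.
    return {"input_ids": [len(word) for word in text.split()]}


def forward(batch: Dict[str, List[List[int]]]) -> List[List[int]]:
    # Batched step: works on padded rows, unaware of individual items.
    return [[2 * tok for tok in row] for row in batch["input_ids"]]


def postprocess(output: List[int], length: int) -> List[int]:
    # Per-item step: strip the padding again using the original length.
    return output[:length]


def pad_collate(items: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    # Generic layer: pad every item to the longest one in the batch.
    max_len = max(len(item["input_ids"]) for item in items)
    return {
        "input_ids": [
            item["input_ids"] + [PAD_ID] * (max_len - len(item["input_ids"]))
            for item in items
        ]
    }


def run_batch(items: List[Dict[str, List[int]]]) -> Iterator[List[int]]:
    outputs = forward(pad_collate(items))
    # Unbatch: hand each row back to the per-item postprocess.
    for item, row in zip(items, outputs):
        yield postprocess(row, len(item["input_ids"]))


def batched_pipeline(inputs: Iterable[str], batch_size: int) -> Iterator[List[int]]:
    batch: List[Dict[str, List[int]]] = []
    for text in inputs:
        batch.append(preprocess(text))
        if len(batch) == batch_size:
            yield from run_batch(batch)
            batch = []
    if batch:  # leftover partial batch
        yield from run_batch(batch)


texts = ["I live in New york", "Hello world", "a"]
results = list(batched_pipeline(texts, batch_size=2))
print(results)
```

Because padding and unbatching live in one shared layer, each task only writes per-item logic, and the results stay the same for any `batch_size` as long as `postprocess` removes the padding.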