huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
134.07k stars 26.81k forks

Support for tokenising higher batch dimensions #27806

Closed saswat0 closed 9 months ago

saswat0 commented 11 months ago

Feature request

Can inference be performed with larger batch dimensions? Currently, tokenisation is supported up to List[List[str]], and any input with more dimensions has to be traversed with a for loop. Can this restriction be lifted so that the entire batch is processed in one go?

Motivation

Tokenisation for a batch of batches of strings has to be done sequentially, by iterating over each inner batch. On a GPU, it should be possible to do this in one go, thereby reducing latency and improving utilisation by a large margin.

Currently, the following snippet works fine:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)

But encoding a batch of such batches for comparison fails:

query_list = [[['what is panda?', 'hi'], ['what is panda?', 'The giant panda is a bear species endemic to China.']] for _ in range(4)]

The following error is raised (irrespective of the tokenizer):

File ~/anaconda3/envs/reranking_project/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/reranking_project/lib/python3.10/site-packages/FlagEmbedding/flag_models.py:149, in FlagReranker.compute_score(self, sentence_pairs, batch_size, max_length)
    146 for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Scores",
    147                         disable=len(sentence_pairs) < 128):
    148     sentences_batch = sentence_pairs[start_index:start_index + batch_size]
--> 149     inputs = self.tokenizer(
    150         sentences_batch,
    151         padding=True,
    152         truncation=True,
    153         return_tensors='pt',
    154         max_length=max_length,
    155     ).to(self.device)
    157     scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float()
    158     all_scores.extend(scores.cpu().numpy().tolist())

File ~/anaconda3/envs/reranking_project/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2602, in PreTrainedTokenizerBase.__call__(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
   2600     if not self._in_target_context_manager:
   2601         self._switch_to_input_mode()
-> 2602     encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
   2603 if text_target is not None:
   2604     self._switch_to_target_mode()

File ~/anaconda3/envs/reranking_project/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2660, in PreTrainedTokenizerBase._call_one(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
   2657         return False
   2659 if not _is_valid_text_input(text):
-> 2660     raise ValueError(
   2661         "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
   2662         "or `List[List[str]]` (batch of pretokenized examples)."
   2663     )
   2665 if text_pair is not None and not _is_valid_text_input(text_pair):
   2666     raise ValueError(
   2667         "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
   2668         "or `List[List[str]]` (batch of pretokenized examples)."
   2669     )

ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

Is there any way to perform inference at once for a List[List[List[str]]] without iterating one by one?

Your contribution

No contribution. Just raising a request.

ArthurZucker commented 10 months ago

You should probably pack your data; I think that will also be more efficient. Separate the sequences with the EOS token.

saswat0 commented 10 months ago

@ArthurZucker I'm not sure I got your point entirely. For now, I'm flattening the whole batch of tensors before passing it into the model and then reshaping the output. Do you mean that I should concatenate the sentences with a separator token before tokenising them? If so, how do I represent the sentence pairs (i.e., along which axis do I concatenate)?
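
The flatten-then-reshape workaround described above could look roughly like the sketch below. It is only an illustration, not code from this thread: the helper names `flatten_groups`, `unflatten` and `score_groups` are hypothetical, and `score_fn` stands in for whatever function turns a List[List[str]] of pairs into one score per pair.

```python
from typing import Callable, List, Sequence, Tuple

Pair = List[str]  # one [query, passage] pair


def flatten_groups(groups: Sequence[Sequence[Pair]]) -> Tuple[List[Pair], List[int]]:
    """Flatten List[List[List[str]]] into List[List[str]] and record each group's size."""
    sizes = [len(group) for group in groups]
    flat = [pair for group in groups for pair in group]
    return flat, sizes


def unflatten(values: Sequence[float], sizes: Sequence[int]) -> List[List[float]]:
    """Split a flat sequence of per-pair values back into the original groups."""
    out, start = [], 0
    for size in sizes:
        out.append(list(values[start:start + size]))
        start += size
    return out


def score_groups(
    groups: Sequence[Sequence[Pair]],
    score_fn: Callable[[List[Pair]], Sequence[float]],
) -> List[List[float]]:
    """Score every pair with a single call to score_fn, then restore the nesting."""
    flat, sizes = flatten_groups(groups)
    return unflatten(score_fn(flat), sizes)


# With the tokenizer and model from the working snippet earlier in the issue,
# score_fn could be defined along these lines (assumption, not run here):
#
# def score_fn(pairs):
#     with torch.no_grad():
#         inputs = tokenizer(pairs, padding=True, truncation=True,
#                            return_tensors='pt', max_length=512)
#         return model(**inputs, return_dict=True).logits.view(-1).float().tolist()
```

Calling `score_groups(query_list, score_fn)` would then tokenise and score all pairs in one batch, provided the flattened batch fits in memory, and return the scores with the same outer nesting as `query_list`.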

ArthurZucker commented 10 months ago

What I mean is that you cannot pass a list of lists of lists of strings to a tokenizer. What you need to do is separate the sequences with a special token, <SEP> for example. I am not sure I understand your last comment, as the tokenizer does not take tensors as input. Models usually support batches of inputs, so there is no need to flatten, but if it works, all right.