huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Using batching with pipeline and transformers #31641

Open arunasank opened 6 days ago

arunasank commented 6 days ago

System Info

Who can help?

@ArthurZucker @Narsil @stevhliu

Information

Tasks

Reproduction

import pandas as pd
import utils
import datasets
import bitsandbytes
from torch.utils.data import DataLoader
from accelerate import infer_auto_device_map
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm
import os
import torch
assert torch.cuda.is_available()

BATCH_SIZE = 2
os.environ['TRANSFORMERS_CACHE'] = '<redacted>'
os.environ['HF_TOKEN']='<redacted>'

dataset_filename: str = "<redacted>"

# Only for Llama-70B-chat
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load dataset
dataset = utils.load_from_jsonl(dataset_filename)
dataframe = pd.DataFrame(dataset[:4])
dataset = datasets.Dataset.from_pandas(dataframe)

pipe = pipeline(
    "text-generation",
    model="meta-llama/Llama-2-70b-chat-hf",
    device_map='auto',
    max_new_tokens=1024,
    model_kwargs={
        "cache_dir": os.environ['TRANSFORMERS_CACHE'],
        "quantization_config": nf4_config,
    },
)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id

responses = []
for d in tqdm(pipe(KeyDataset(dataset, 'prompt'), batch_size=BATCH_SIZE)):
    print('here!')
    responses.extend(d)

# Save responses
with open('llama-70-responses.jsonl', 'a') as f:
    for response in responses:
        f.write(response['generated_text'] + "\n")

print("Processing complete.")
  1. Code hangs at for d in tqdm(pipe(KeyDataset(dataset, 'prompt'), batch_size=BATCH_SIZE)): and doesn't progress. No errors are thrown.
  2. Code works when batch_size is not provided in the above line, but tqdm only shows progress once the whole run has completed (see the progress-bar sketch after this list).
  3. Alternatively, I tried to provide the batch_size when constructing the pipeline: pipe = pipeline("text-generation", model="meta-llama/Llama-2-70b-chat-hf", device_map='auto', max_new_tokens=1024, batch_size=BATCH_SIZE, model_kwargs={"cache_dir": os.environ['TRANSFORMERS_CACHE'], "quantization_config": nf4_config}), but that doesn't work either.
  4. The example from https://huggingface.co/docs/transformers/en/main_classes/pipelines#pipeline-batching does not work with Llama-2-70b-chat-hf.
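
For item 2, a rough sketch of a workaround (not from the docs; it reuses pipe, dataset, and BATCH_SIZE from the script above, and total=len(dataset) is my own assumption about how to bound the bar) would be to drive tqdm manually:

# Rough sketch, not an official fix: pipe(...) returns a generator, so tqdm
# cannot infer its length; passing total=len(dataset) gives a bounded bar
# that should advance once per input as results are yielded.
progress = tqdm(total=len(dataset))
responses = []
for d in pipe(KeyDataset(dataset, 'prompt'), batch_size=BATCH_SIZE):
    responses.extend(d)
    progress.update(1)
progress.close()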

Expected behavior

I would expect the code to work with batching. Clear documentation on batching when using pipelines would be appreciated.

amyeroberts commented 6 days ago

cc @Rocketknight1

arunasank commented 6 days ago

UPDATE: I got the code to work with batching. It's just very hard to gauge progress because tqdm does not report any progress until the whole pipeline has finished the task. Per-item times are not reported either, which makes it impossible to sub-sample and estimate the total runtime.

TL;DR - 1, 3, and 4 are resolved; 2 remains an issue.
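
A rough sketch of such a sub-sampled timing run (the slice size of 16 is an arbitrary assumption on my part; it reuses pipe, dataset, and BATCH_SIZE from the script above) could look like this:

import time

# Time a small slice and extrapolate to the full dataset.
subset = dataset.select(range(min(16, len(dataset))))
start = time.perf_counter()
for _ in pipe(KeyDataset(subset, 'prompt'), batch_size=BATCH_SIZE):
    pass
elapsed = time.perf_counter() - start
print(f"~{elapsed / len(subset):.1f} s per input over {len(subset)} inputs")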

UPDATE: It doesn't work! It works for small datasets and smaller batch sizes, but fails silently, with no errors, on a large dataset with a reasonable batch size. For example, the script above uses BATCH_SIZE = 2 and a dataset with 4 inputs; that took 4 minutes to run with Llama-2-70b-chat on an A100. When I ran the full dataset of ~7200 inputs with a batch size of 4, it hung on the for d in tqdm(pipe(KeyDataset(dataset, 'prompt'), batch_size=BATCH_SIZE)): line, again with no errors.
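
One rough diagnostic sketch (SHARD_SIZE is an arbitrary choice of mine, not something from the original script) would be to run the same pipeline over fixed-size shards and flush results after each one, so that a stall can at least be localized to a particular shard:

# Process the dataset in fixed-size shards and write results incrementally;
# if a specific shard never finishes, that narrows down where the hang occurs.
SHARD_SIZE = 256

with open('llama-70-responses.jsonl', 'a') as f:
    for start in range(0, len(dataset), SHARD_SIZE):
        shard = dataset.select(range(start, min(start + SHARD_SIZE, len(dataset))))
        for d in pipe(KeyDataset(shard, 'prompt'), batch_size=BATCH_SIZE):
            for out in d:
                f.write(out['generated_text'] + "\n")
        f.flush()
        print(f"finished shard starting at input {start}")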

not-lain commented 6 days ago

@arunasank have you tried using the map method?

def fn(batch):
    batch["new_col"] = pipe(batch["prompt"])
    return batch

dataset = dataset.map(fn, batched=True, batch_size=BATCH_SIZE)
responses = dataset["new_col"]

arunasank commented 6 days ago

Yes, I tried that both with and without setting batch_size on the pipeline, and it didn't work either!

Rocketknight1 commented 5 days ago

Hi @arunasank, have you tried other models and other methods of dataset creation to see if the issue recurs? It'd be helpful if you could put together a minimal, specific reproducer for this issue so that we can run it here and figure out what's going on!
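
As a starting point, a rough sketch of such a reproducer (gpt2 and the synthetic prompts are stand-in assumptions, not part of the original report) would check whether the KeyDataset + batch_size path hangs independently of the 70B model and 4-bit quantization:

import datasets
from tqdm.auto import tqdm
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

# Small public model and synthetic prompts so the batching path can be
# exercised quickly, without the 70B weights or quantization in the picture.
dataset = datasets.Dataset.from_dict({"prompt": [f"Prompt number {i}" for i in range(64)]})
pipe = pipeline("text-generation", model="gpt2", max_new_tokens=32)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id

for _ in tqdm(pipe(KeyDataset(dataset, "prompt"), batch_size=4), total=len(dataset)):
    pass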