NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.

Donut inference is extremely slow when loading from checkpoint #276

Open · MS1908 opened this issue 1 year ago

MS1908 commented 1 year ago

I'm trying to use the Donut model for document classification on a custom dataset (in a format similar to RVL-CDIP). When I train the model and run inference (via the model.generate() method) inside the training loop for evaluation, it behaves normally: inference takes about 0.2s per image.

However, if I save the trained model to a checkpoint with the save_pretrained method and then reload it with the from_pretrained method, model.generate() runs extremely slowly (6 to 7s per image).
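
For reference, the checkpoint round trip looks roughly like this (a minimal sketch, not my full training script; model and processor are the trained objects and CKPT_PATH is a local directory):

from transformers import DonutProcessor, VisionEncoderDecoderModel

# after training: write the config, weights and processor files to disk
model.save_pretrained(CKPT_PATH)
processor.save_pretrained(CKPT_PATH)

# later, for standalone inference: reload both from the same directory
model = VisionEncoderDecoderModel.from_pretrained(CKPT_PATH)
processor = DonutProcessor.from_pretrained(CKPT_PATH)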

Here is the code I use for inference (the inference code inside the training loop is exactly the same):

import re
import time

import numpy as np
import torch
from tqdm import tqdm
from transformers import VisionEncoderDecoderModel

# CKPT_PATH, config, processor and val_ds come from my training setup
model = VisionEncoderDecoderModel.from_pretrained(CKPT_PATH, config=config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

accs = []
model.eval()
for i, sample in tqdm(enumerate(val_ds), total=len(val_ds)):
    pixel_values = sample["pixel_values"]
    pixel_values = torch.unsqueeze(pixel_values, 0)
    pixel_values = pixel_values.to(device)

    start = time.time()
    task_prompt = "<s_fci>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    print(f"Tokenize time: {time.time() - start:.4f}s")

    start = time.time()
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    print(f"Inference time: {time.time() - start:.4f}s")

    # turn into JSON
    start = time.time()
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)
    if "class" not in seq.keys():
        seq["class"] = "other"
    print(f"Decoding time: {time.time() - start:.4f}s")

    gt = sample["labels"]
    score = float(seq["class"] == gt["class"])

    accs.append(score)

acc_score = np.mean(accs)

print(f"Accuracy: {acc_score * 100:.4f}%")

I run the model on an NVIDIA A100 40GB GPU, in an Anaconda environment with the following requirements:

cudatoolkit==11.7
torch==1.13.1+cu117
torchvision==0.14.1+cu117
datasets==2.10.1
transformers==4.26.1
sentencepiece==0.1.97
onnx==1.12.0
protobuf==3.20.0
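
To double-check that the runtime environment matches the list above, a quick Python check:

import torch
import transformers

print(torch.__version__, torch.version.cuda)  # expecting 1.13.1+cu117, 11.7
print(transformers.__version__)               # expecting 4.26.1
print(torch.cuda.is_available(), torch.cuda.get_device_name(0))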

Could you have a look at this issue, @NielsRogge? Thank you very much.

Altimis commented 4 months ago

Hi @MS1908, did you manage to find a solution for this slowness?