NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

Can `VisionEncoderDecoderModel.generate()` work with batched data? #355

Open plamb-viso opened 8 months ago

plamb-viso commented 8 months ago

Sorry if this is the wrong place to post this.

I'm currently trying to fine-tune Donut using your excellent fine-tuning guide as a starting point. As a test, I am calling `VisionEncoderDecoderModel.generate()` like so:

outputs = model_generator.generate(
    batch['pixel_values'],
    decoder_input_ids=decoder_input_ids,
    max_length=model_generator.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=False,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

batch['pixel_values'] has a shape like so: pixel_values_shape=torch.Size([2, 3, 500, 647])

That is, two page images are passed to .generate(). As such, I was expecting outputs.sequences to have a leading dimension of two, but it is always one: outputs=torch.Size([1, 5]) output=tensor([57552, 57550, 57526, 57551, 2], device='cuda:3')

I realize I can iterate the batch and generate each sequence independently, but I'd be amazed if .generate() can't handle batch data like this. Is there something I'm missing?

NielsRogge commented 8 months ago

Hi,

Yes, in theory the VisionEncoderDecoderModel class should support batched generation. Could you open an issue about this on the Transformers repository (if there is no such issue yet)?
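One thing worth checking in the call from the question, offered as an assumption rather than a confirmed diagnosis: if `decoder_input_ids` was built for a single prompt, its leading dimension would not match the two images in `pixel_values`. The sketch below uses plain Python lists so it runs without a model. `repeat_prompt` is a hypothetical helper, not part of Transformers, and the token ID is a placeholder.

```python
def repeat_prompt(prompt_ids, batch_size):
    """Return one copy of the decoder prompt per image in the batch.

    prompt_ids: token IDs of a single decoder prompt (hypothetical input).
    batch_size: leading dimension of pixel_values, e.g. pixel_values.shape[0].
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    # Copy the list so that each row can be modified independently later.
    return [list(prompt_ids) for _ in range(batch_size)]


# Placeholder token ID; in practice it would come from the tokenizer.
rows = repeat_prompt([0], batch_size=2)
print(len(rows), len(rows[0]))  # rows per batch, tokens per row
```

The rows could then be converted with `torch.tensor(rows)` and moved to the model's device before being passed as `decoder_input_ids`.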

janakiram180 commented 2 months ago

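from PIL import Image

def ocr_image(batch_image):
    # batch_image: list of image file paths.
    # processor and model are assumed to be loaded already.
    image_list = []
    for image_path in batch_image:
        image = Image.open(image_path)
        image_list.append(image)
    # Preprocess all images together so generate() receives one batch.
    inputs = processor(images=image_list, return_tensors="pt")
    generated_ids = model.generate(**inputs)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))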

This code works for a list of image paths.
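For a long list of image paths, loading every image into a single batch may use more memory than is available. A minimal sketch of splitting the paths into fixed-size batches follows. The helper name `chunked` and the file names are hypothetical, and each chunk could be passed to a function like `ocr_image`.

```python
def chunked(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Hypothetical file names, for illustration only.
paths = ["page1.png", "page2.png", "page3.png"]
for batch_paths in chunked(paths, batch_size=2):
    # Each batch_paths list holds at most two paths.
    print(batch_paths)
```

The batch size is a trade-off: larger batches make better use of the GPU, while smaller ones lower peak memory.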