sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

Failing in the input stage using an encoder-decoder model architecture (DONUT model) #192

Open swapnil-lader opened 1 year ago

swapnil-lader commented 1 year ago

I am trying to get a model summary of the Donut model, but I am unable to define the inputs for torchsummary.

    import argparse

    import gradio as gr
    import torch
    from PIL import Image
    from donut.donut.model import DonutModel
    from torchvision import models
    from torchsummary import summary

    def demo_process_vqa(input_img, question):
        global pretrained_model, task_prompt, task_name

        pretrained_model = './donut/result/train_docvqa/20220912_103244'

        # task_name = "docvqa"
        # task_prompt = "<s_pdf-donut>"
        input_img = Image.fromarray(input_img)
        user_prompt = task_prompt.replace("{user_input}", question)
        print(user_prompt)
        output = pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
        print('inf_out', output)
        return output

    def demo_process(input_img):
        global pretrained_model, task_prompt, task_name
        input_img = Image.fromarray(input_img)
        output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
        return output

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="docvqa")
    parser.add_argument("--pretrained_path", type=str, default="train_docvqa_for_all_atts/donut/result/train_docvqa/20220915_125713")
    args, left_argv = parser.parse_known_args()

    task_name = args.task
    if "docvqa" == task_name:
        task_prompt = "{user_input}"
    else:  # rvlcdip, cord, ...
        task_prompt = f"<s_{task_name}>"

    pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

    if torch.cuda.is_available():
        pretrained_model.half()
        device = torch.device("cuda")
        pretrained_model.to(device)
    else:
        pretrained_model.encoder.to(torch.bfloat16)

    summary(pretrained_model, [(1, 3, 1280, 960), (1, 21), (1, 21)])
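If I read the torchsummary source correctly, `summary` builds a random `FloatTensor` from each `input_size` tuple and prepends its own batch dimension, so the leading `1` above is redundant and the decoder's token IDs arrive as floats, which would break the embedding lookup. A sketch of what I assume the call should look like, using the `dtypes` argument that the master branch of this repo accepts:

    # sketch: batch dimension omitted because summary prepends one itself;
    # LongTensor so the decoder's embedding layer receives integer token IDs
    summary(
        pretrained_model,
        [(3, 1280, 960), (21,), (21,)],
        dtypes=[torch.FloatTensor, torch.LongTensor, torch.LongTensor],
    )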

The input shapes for the encoder and decoder are as follows:

    Encoder: torch.Size([1, 3, 1280, 960])
    Decoder: torch.Size([1, 21])

The model's forward method looks like this:

    def forward(self, image_tensors, decoder_input_ids, decoder_labels):
        # encode the document image, then decode conditioned on the encoder output
        encoder_outputs = self.encoder(image_tensors)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_outputs,
            labels=decoder_labels,
        )
        return decoder_outputs
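Given that signature, one alternative I am considering is torchinfo (the maintained successor to torchsummary), which accepts real tensors via `input_data`, so the dtypes and shapes come straight from the tensors. A minimal sketch, assuming token ID 0 is valid for the decoder vocabulary:

    import torch
    from torchinfo import summary as torchinfo_summary

    # dummy inputs matching the shapes above; move/cast to the model's
    # device and dtype if needed (e.g. after .half().to(device))
    image = torch.randn(1, 3, 1280, 960)
    decoder_input_ids = torch.zeros(1, 21, dtype=torch.long)
    decoder_labels = torch.zeros(1, 21, dtype=torch.long)

    torchinfo_summary(
        pretrained_model,
        input_data=[image, decoder_input_ids, decoder_labels],
    )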

Can you please advise on how to pass the model inputs to `summary`?