clovaai / donut

Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022
https://arxiv.org/abs/2111.15664
MIT License
5.74k stars, 466 forks

Input type (float) and bias type (struct c10::BFloat16) should be the same #267

Status: Open. Coder-Vishali opened this issue 11 months ago.

Coder-Vishali commented 11 months ago

When I try to run the following code:

```python
from donut import DonutModel
import torch
from PIL import Image

pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base")
if torch.cuda.is_available():
    # GPU path: whole model in float16, moved to CUDA
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
else:
    # CPU path: only the encoder weights are cast to bfloat16
    pretrained_model.encoder.to(torch.bfloat16)
pretrained_model.eval()

task_name = "synthdog"
task_prompt = f"<s_{task_name}>"

input_img = Image.open(r"C:\Project_Files\donut_vqa\image\1_GREEN_crop_31.jpg")
output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
print(output)
```

I get the following error:

```
RuntimeError                              Traceback (most recent call last)
c:\Project_Files\donut_vqa\colab_demo_for_donut_base_finetuned_docvqa_230615.ipynb Cell 7 line 1
     16 task_prompt = f"<s_{task_name}>"
     18 input_img = Image.open(r"C:\Project_Files\donut_vqa\image\crop_31.jpg")
---> 19 output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
     20 print(output)

File c:\Project_Files\donut_vqa\.venv\lib\site-packages\donut\model.py:452, in DonutModel.inference(self, image, prompt, image_tensors, prompt_tensors, return_json, return_attentions)
    448     prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
    450 prompt_tensors = prompt_tensors.to(self.device)
--> 452 last_hidden_state = self.encoder(image_tensors)
    453 if self.device.type != "cuda":
    454     last_hidden_state = last_hidden_state.to(torch.float32)

File c:\Project_Files\donut_vqa\.venv\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File c:\Project_Files\donut_vqa\.venv\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
...
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Input type (float) and bias type (struct c10::BFloat16) should be the same
```

What should the input image shape be? The image I'm using has shape (19, 273, 3).

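The shape is probably not what the error is about: the message compares the dtype of the input tensor (float32) with the dtype of a convolution's bias (bfloat16). A minimal, Donut-free sketch of the same kind of mismatch, assuming PyTorch 2.x with bfloat16 convolution support on CPU; the layer sizes are arbitrary stand-ins, not Donut's real configuration:

```python
import torch
import torch.nn as nn

# Arbitrary stand-in for an encoder's first conv layer (not Donut's real config),
# cast to bfloat16 the way the CPU branch in the issue casts the encoder.
conv = nn.Conv2d(3, 8, kernel_size=4, stride=4).to(torch.bfloat16)

# A freshly created image tensor is float32 by default.
x = torch.rand(1, 3, 32, 32)

try:
    conv(x)
    mismatch = None
except RuntimeError as err:
    # float32 input vs. bfloat16 weight/bias; exact wording varies by PyTorch version.
    mismatch = err
print(type(mismatch).__name__)  # RuntimeError

# Option 1: cast the input to the layer's parameter dtype.
param_dtype = next(conv.parameters()).dtype
out_bf16 = conv(x.to(param_dtype))
print(out_bf16.dtype, tuple(out_bf16.shape))  # torch.bfloat16 (1, 8, 8, 8)

# Option 2: keep the layer in float32 (the default) instead.
out_fp32 = conv.float()(x)
print(out_fp32.dtype)  # torch.float32
```

Mapped onto the issue, Option 2 roughly corresponds to not casting the encoder to bfloat16 on CPU. That is an untested suggestion under the assumptions above, not a confirmed fix for donut-python.
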
bugface commented 10 months ago

Check your package versions. You might want to stick with the exact versions listed in the project's requirements.txt.
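One way to act on this is to compare what is installed against the pins in a requirements file. A stdlib-only sketch: `parse_pins` and `find_mismatches` are hypothetical helper names, only plain `name==version` lines are handled (ranges, extras and environment markers are skipped), and the sample pins are copied from the environment listed further down this thread for illustration, not from the project's requirements.txt:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple


def parse_pins(text: str) -> Dict[str, str]:
    """Collect `name==version` pins, skipping blanks, comments and other specifiers."""
    pins = {}
    for raw in text.splitlines():
        # Drop trailing comments and an optional leading "- " list marker.
        line = raw.split("#", 1)[0].strip().lstrip("- ").strip()
        if "==" not in line:
            continue
        name, _, pinned = line.partition("==")
        pins[name.strip()] = pinned.strip()
    return pins


def find_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (pinned, installed)} for every pin that is not satisfied exactly."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


# Illustrative input only; paste the project's real requirements.txt instead.
sample = """\
- torch==2.2.1
- timm==0.9.16
- transformers==4.39.0
"""

for name, (pinned, installed) in find_mismatches(parse_pins(sample)).items():
    print(f"{name}: pinned {pinned}, installed {installed}")
```
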

skittoo commented 9 months ago

> Check your package versions. You might want to stick with the exact versions listed in the project's requirements.txt.

I think this problem occurs when you run on the CPU instead of a GPU.

skittoo commented 9 months ago

Use a GPU instead of the CPU and try again.
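Without a GPU, another option is to keep the model and its inputs on one device and in one dtype, for example float16 on CUDA (as in the issue's GPU branch) and plain float32 on CPU, casting each input to the model's parameter dtype before the forward pass. A generic PyTorch sketch under those assumptions: `pick_dtype` and `to_model` are hypothetical helpers, a small `nn.Sequential` stands in for the real model, and whether Donut's own `inference` runs cleanly this way depends on the installed donut-python version:

```python
import torch
import torch.nn as nn


def pick_dtype(device: torch.device) -> torch.dtype:
    """float16 on CUDA (mirrors the .half() GPU branch), float32 everywhere else."""
    return torch.float16 if device.type == "cuda" else torch.float32


def to_model(model: nn.Module, inp: torch.Tensor) -> torch.Tensor:
    """Move/cast an input tensor to the device and dtype of the model's parameters."""
    param = next(model.parameters())
    return inp.to(device=param.device, dtype=param.dtype)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = pick_dtype(device)

# Stand-in model: one conv layer plus global pooling (not Donut's architecture).
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=4, stride=4), nn.AdaptiveAvgPool2d(1))
model = model.to(device=device, dtype=dtype).eval()

image = torch.rand(1, 3, 64, 64)  # float32, on CPU
with torch.no_grad():
    features = model(to_model(model, image))

print(features.dtype, tuple(features.shape))
```

Casting the input to whatever the parameters already are keeps the two in sync even if the model's dtype is changed later.
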

java2python commented 6 months ago

I encountered the same problem with Python 3.11 and the following requirements.txt:

- aiofiles==23.2.1
- aiohttp==3.9.3
- aiosignal==1.3.1
- altair==5.2.0
- annotated-types==0.6.0
- anyio==4.3.0
- asgiref==3.8.0
- attrs==23.2.0
- certifi==2024.2.2
- charset-normalizer==3.3.2
- click==8.1.7
- colorama==0.4.6
- contourpy==1.2.0
- cycler==0.12.1
- datasets==2.18.0
- dill==0.3.8
- Django==5.0.3
- donut-python==1.0.9
- fastapi==0.110.0
- ffmpy==0.3.2
- filelock==3.13.1
- fonttools==4.50.0
- frozenlist==1.4.1
- fsspec==2024.2.0
- gradio==4.22.0
- gradio_client==0.13.0
- h11==0.14.0
- httpcore==1.0.4
- httpx==0.27.0
- huggingface-hub==0.21.4
- idna==3.6
- importlib_resources==6.3.2
- Jinja2==3.1.3
- joblib==1.3.2
- jsonschema==4.21.1
- jsonschema-specifications==2023.12.1
- kiwisolver==1.4.5
- lightning-utilities==0.11.0
- markdown-it-py==3.0.0
- MarkupSafe==2.1.5
- matplotlib==3.8.3
- mdurl==0.1.2
- mpmath==1.3.0
- multidict==6.0.5
- multiprocess==0.70.16
- munch==4.0.0
- networkx==3.2.1
- nltk==3.8.1
- numpy==1.26.4
- orjson==3.9.15
- packaging==24.0
- pandas==2.2.1
- pillow==10.2.0
- pyarrow==15.0.2
- pyarrow-hotfix==0.6
- pydantic==2.6.4
- pydantic_core==2.16.3
- pydub==0.25.1
- Pygments==2.17.2
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- python-multipart==0.0.9
- pytorch-lightning==1.6.4
- pytz==2024.1
- PyYAML==6.0.1
- referencing==0.34.0
- regex==2023.12.25
- requests==2.31.0
- rich==13.7.1
- rpds-py==0.18.0
- ruamel.yaml==0.18.6
- ruamel.yaml.clib==0.2.8
- ruff==0.3.3
- safetensors==0.4.2
- sconf==0.2.5
- semantic-version==2.10.0
- sentencepiece==0.2.0
- shellingham==1.5.4
- six==1.16.0
- sniffio==1.3.1
- sqlparse==0.4.4
- starlette==0.36.3
- sympy==1.12
- timm==0.9.16
- tokenizers==0.15.2
- tomlkit==0.12.0
- toolz==0.12.1
- torch==2.2.1
- torchmetrics==1.3.2
- torchvision==0.17.1
- tqdm==4.66.2
- transformers==4.39.0
- typer==0.9.0
- typing_extensions==4.10.0
- tzdata==2024.1
- urllib3==2.2.1
- uvicorn==0.29.0
- websockets==11.0.3
- xxhash==3.4.1
- yarl==1.9.4
- zss==1.2.0