shabie / docformer

Implementation of DocFormer: End-to-End Transformer for Document Understanding, a multi-modal transformer-based architecture for the task of Visual Document Understanding (VDU)

Error When Following the Usage Instructions #27

Closed · ynusinovich closed this 2 years ago

ynusinovich commented 2 years ago

I tried following the usage instructions you posted on a sample .jpg image of a receipt. Every time I run it, I get an error saying, "RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 384, 500] instead". How do I fix that?

Full code:

import pytesseract
import sys 
sys.path.extend(['docformer/src/docformer/'])
import modeling, dataset
from transformers import BertTokenizerFast

config = {
  "coordinate_size": 96,
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "image_feature_pool_shape": [7, 7, 256],
  "intermediate_ff_size_factor": 4,
  "max_2d_position_embeddings": 1000,
  "max_position_embeddings": 512,
  "max_relative_positions": 8,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "shape_size": 96,
  "vocab_size": 30522,
  "layer_norm_eps": 1e-12,
}

fp = "images/data_sample.jpg"

pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'  # set before create_features runs OCR

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = dataset.create_features(fp, tokenizer)

feature_extractor = modeling.ExtractFeatures(config)
docformer = modeling.DocFormerEncoder(config)

v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding)
output = docformer(v_bar, t_bar, v_bar_s, t_bar_s)  # shape (1, 512, 768)

Full error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [3], in <module>
     31 feature_extractor = modeling.ExtractFeatures(config)
     32 docformer = modeling.DocFormerEncoder(config)
---> 34 v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding)
     35 output = docformer(v_bar, t_bar, v_bar_s, t_bar_s)

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\Documents\Projects\docformer_implementation\docformer/src/docformer\modeling.py:512, in ExtractFeatures.forward(self, encoding)
    509 x_feature = encoding['x_features']
    510 y_feature = encoding['y_features']
--> 512 v_bar = self.visual_feature(image)
    513 t_bar = self.language_feature(language)
    515 v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\Documents\Projects\docformer_implementation\docformer/src/docformer\modeling.py:48, in ResNetFeatureExtractor.forward(self, x)
     47 def forward(self, x):
---> 48     x = self.resnet50(x)
     49     x = self.conv1(x)
     50     x = self.relu1(x)

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\container.py:141, in Sequential.forward(self, input)
    139 def forward(self, input):
    140     for module in self:
--> 141         input = module(input)
    142     return input

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\conv.py:446, in Conv2d.forward(self, input)
    445 def forward(self, input: Tensor) -> Tensor:
--> 446     return self._conv_forward(input, self.weight, self.bias)

File ~\anaconda3\envs\docformer_env\lib\site-packages\torch\nn\modules\conv.py:442, in Conv2d._conv_forward(self, input, weight, bias)
    438 if self.padding_mode != 'zeros':
    439     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    440                     weight, bias, self.stride,
    441                     _pair(0), self.dilation, self.groups)
--> 442 return F.conv2d(input, weight, bias, self.stride,
    443                 self.padding, self.dilation, self.groups)

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 384, 500] instead

uakarsh commented 2 years ago

Have a look here: https://github.com/uakarsh/docformer/blob/master/examples/DocFormer_for_MLM.ipynb

The error occurs because the input is not batched, i.e., it has a shape of (...) rather than (batch_size, ...).
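
For reference, a quick way to confirm this is to print the shapes coming out of dataset.create_features. A minimal check, assuming the encoding dict holds torch tensors as in the snippet above:

import torch

for key, value in encoding.items():
    if torch.is_tensor(value):
        print(key, tuple(value.shape))
# an unbatched image prints as (3, 384, 500); the ResNet's Conv2d expects (batch_size, 3, 384, 500)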

ynusinovich commented 2 years ago

@uakarsh Thank you for your help! Does this mean that the Usage section of the README can't actually be used as written? I was trying to demo it for my study group. I tried encoding['resized_scaled_img'] = encoding['resized_scaled_img'].unsqueeze(0) to add a batch size of 1, but that didn't work either.

uakarsh commented 2 years ago

It can be used; you just need to pass the argument add_batch_dim=True to the dataset.create_features function.
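
For example, the usage snippet above then becomes (a minimal sketch of the fix; the rest of the code is unchanged):

encoding = dataset.create_features(fp, tokenizer, add_batch_dim=True)  # adds a leading batch dimension to every feature

v_bar, t_bar, v_bar_s, t_bar_s = feature_extractor(encoding)
output = docformer(v_bar, t_bar, v_bar_s, t_bar_s)  # shape (1, 512, 768)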

uakarsh commented 2 years ago

The change you made on its own also won't work, because there are more features than just the image, i.e., you need to unsqueeze the other features as well. I have updated the README; hope it helps.
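
For completeness, the manual route would have to add the batch dimension to every tensor in the encoding, not only the image. A sketch, assuming all values in the dict are unbatched torch tensors:

import torch

for key, value in encoding.items():
    if torch.is_tensor(value):
        encoding[key] = value.unsqueeze(0)  # prepend a batch dimension of 1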

ynusinovich commented 2 years ago

Thank you so much, it runs now! Unsqueezing each feature also works for me, but add_batch_dim is more straightforward. Are there any examples of follow-up steps (i.e., what the resulting tensor means in terms of the input image)? I can't find that in the README or the examples.

uakarsh commented 2 years ago

Maybe you can have a look at the notebook I shared previously. In that notebook, you can go through the DocFormerForMLM class and look at the forward method there. I will briefly describe it here:

All the shapes below are for the default configuration.

  1. self.embeddings is responsible for encoding the spatial features of the bounding boxes (size -> (512, 768))
  2. self.resnet is responsible for extracting the image features (size -> (512, 768))
  3. self.lang_emb is responsible for extracting the language features from the words in the bounding boxes (size -> (512, 768))
  4. self.encoder calculates the attention and forward propagates it (size -> (512, 768))

And then, for the downstream task, linear layers are attached on top. Hope it helps.
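
Roughly, that forward method can be pictured as the sketch below. This is only a sketch: the attribute and key names follow the list above and the earlier snippet, and the exact signatures in the notebook may differ.

def forward(self, encoding):
    # 1. spatial embeddings of the bounding boxes -> (batch, 512, 768) each
    v_bar_s, t_bar_s = self.embeddings(encoding['x_features'], encoding['y_features'])
    # 2. visual features from the ResNet branch -> (batch, 512, 768)
    v_bar = self.resnet(encoding['resized_scaled_img'])
    # 3. language features from the token ids -> (batch, 512, 768)
    t_bar = self.lang_emb(encoding['input_ids'])
    # 4. multi-modal attention over the three streams -> (batch, 512, 768)
    out = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)
    # a task-specific head (e.g. an MLM linear layer) is attached on top of this output
    return out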

ynusinovich commented 2 years ago

Ok, understood, thank you very much for your help. I'll close the issue since the example runs!