cramraj8 opened this issue 1 year ago
Same issue for me, though I am using PyTorch 1.7+cu110 and timm==0.4.5.

Errors with trocr-base and trocr-large:

In the end, I get the error below for all three cases (trocr-small, trocr-base, trocr-large):
```
File "/workspace/unilm_abhishek/trocr/deit.py", line 109, in forward_features
  if self.dist_token is None:
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 778, in __getattr__
  raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'AdaptedVisionTransformer' object has no attribute 'dist_token'
```
Errors with trocr-small:
I had to add this line: `self.distilled = kwargs.pop('distilled')`
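The workaround above can be sketched as follows. This is a toy illustration, not the actual TrOCR code: `BaseViT` is a stand-in for timm's `VisionTransformer` (which rejects unknown keyword arguments), and the class/attribute names besides `dist_token` and `distilled` are hypothetical.

```python
import torch.nn as nn

class BaseViT(nn.Module):
    """Stand-in for timm's VisionTransformer, which rejects unknown kwargs."""
    def __init__(self, embed_dim=768):
        super().__init__()
        self.embed_dim = embed_dim

class AdaptedViTSketch(BaseViT):
    def __init__(self, **kwargs):
        # Pop the custom key first, so the base constructor never sees it;
        # a default avoids a KeyError when the caller omits 'distilled'.
        distilled = kwargs.pop('distilled', False)
        super().__init__(**kwargs)
        self.distilled = distilled
        if not self.distilled:
            # forward_features tests `self.dist_token is None`, so the
            # attribute must exist even on non-distilled models.
            self.dist_token = None

model = AdaptedViTSketch(distilled=False, embed_dim=384)
```

Popping with a default (`kwargs.pop('distilled', False)`) also keeps the constructor usable with configs that never set the key.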
For trocr-small I then got a key-mismatch error for "dist_token", "head_dist.weight", and "head_dist.bias", because the pre-trained models contain those keys.
I handled this by updating this line
https://github.com/microsoft/unilm/blob/1111feedba3cf3612a69aaa3d8546942d07f9800/trocr/deit.py#L227
to:

```python
model.load_state_dict(checkpoint["model"], strict=False)
```
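What `strict=False` does can be shown with a toy model (not the real TrOCR model): matching keys are loaded, and the rest are reported in the return value instead of raising.

```python
import torch
import torch.nn as nn

# Toy illustration of load_state_dict(..., strict=False): the extra
# distillation-only key is reported, not raised.
model = nn.Linear(4, 2)
checkpoint = {"model": {
    "weight": torch.zeros(2, 4),
    "bias": torch.zeros(2),
    "dist_token": torch.zeros(1, 1, 4),  # distillation-only key, as in DeiT
}}
result = model.load_state_dict(checkpoint["model"], strict=False)
print(result.missing_keys)     # keys the model wanted but the checkpoint lacks
print(result.unexpected_keys)  # checkpoint keys the model has no slot for
```

Note that `strict=False` only tolerates missing/extra keys; it does not help when a shared key has the wrong tensor shape.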
Then I got:

```
File "/workspace/unilm_abhishek/trocr/trocr_models.py", line 236, in build_model
  missing_keys, unexpected_keys = decoder.load_state_dict(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
  raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
size mismatch for embed_tokens.weight: copying a param with shape torch.Size([50265, 768]) from checkpoint, the shape in current model is torch.Size([50265, 256]).
```
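This size mismatch suggests the checkpoint holds a base-size (768-dim) decoder embedding while the current model was built with a small (256-dim) decoder, i.e. a model-config/checkpoint mismatch rather than a missing key, so `strict=False` cannot rescue it. A hypothetical helper (not part of the TrOCR codebase) to audit shapes before loading:

```python
import torch

def shape_mismatches(model_state, ckpt_state):
    """Keys present in both state dicts whose tensor shapes disagree."""
    return [k for k in ckpt_state
            if k in model_state and model_state[k].shape != ckpt_state[k].shape]

# Toy reproduction of the mismatch from the traceback above: base-size
# checkpoint embedding (50265 x 768) vs. small-model embedding (50265 x 256).
model_state = {"embed_tokens.weight": torch.zeros(50265, 256)}
ckpt_state = {"embed_tokens.weight": torch.zeros(50265, 768)}
print(shape_mismatches(model_state, ckpt_state))  # ['embed_tokens.weight']
```

Running such a check against `checkpoint["model"]` and `model.state_dict()` before calling `load_state_dict` makes the base-vs-small confusion visible up front.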
NOTE:

```python
self.pretrained_cfg = kwargs.pop('pretrained_cfg')
self.pretrained_cfg_overlay = kwargs.pop('pretrained_cfg_overlay')
```
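One caveat about the two pops in the NOTE: written without a default, `kwargs.pop('pretrained_cfg')` raises a `KeyError` whenever the installed timm version does not inject that key. A sketch of a more tolerant variant (the helper name is hypothetical; only the two key names come from the NOTE):

```python
# Sketch: pop with a default so the constructor works whether or not the
# installed timm version injects these keys.
def split_timm_extras(kwargs):
    pretrained_cfg = kwargs.pop('pretrained_cfg', None)
    pretrained_cfg_overlay = kwargs.pop('pretrained_cfg_overlay', None)
    return pretrained_cfg, pretrained_cfg_overlay, kwargs

cfg, overlay, rest = split_timm_extras({'embed_dim': 256})
print(cfg, overlay, rest)  # None None {'embed_dim': 256}
```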
I switched to timm==0.5.4, and that resolved the following error:

```
File "/workspace/unilm_abhishek/trocr/deit.py", line 109, in forward_features
  if self.dist_token is None:
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 778, in __getattr__
  raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'AdaptedVisionTransformer' object has no attribute 'dist_token'
```
Different trained TrOCR models require custom arg changes in the script

Model I am using: TrOCR

The problem arises when using:
When I load the trocr-base-str.pt model, nothing breaks. But when I load trocr-small-stage1.pt or trocr-small-printed.pt, the model throws an error. Even after adding the arguments below to the AdaptedVisionTransformer class, model loading still throws an "unexpected keys found" error.

Error I am still getting after the above change:

To Reproduce
Steps to reproduce the behavior: