rishabhgupta93 opened this issue 11 months ago
Any update on this? I am also having the same issue.
I also ran into this issue. I found that the second argument of view in loss_fn must equal the number of label classes. For example, if you have three labels (location, organization, person), the line return nn.CrossEntropyLoss()(pred.view(-1,3),target.view(-1)) must use view(-1,3), so this number must be 3.
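A minimal sketch of a loss_fn that avoids the hardcoded class count, assuming pred holds per-token logits of shape (batch, seq_len, num_classes) and target holds one class id per token (the exact shapes in the repo's trainer.py are not shown in this thread). Reading the count from pred.size(-1) means the loss does not need editing when the number of classes changes.

```python
import torch
import torch.nn as nn


def loss_fn(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: pred is (batch, seq_len, num_classes), target is (batch, seq_len).
    # Take the class count from the logits instead of hardcoding 3, 4 or 5,
    # so both flattened tensors have batch * seq_len rows.
    num_classes = pred.size(-1)
    return nn.CrossEntropyLoss()(pred.view(-1, num_classes), target.view(-1))


if __name__ == "__main__":
    torch.manual_seed(0)
    for n in (3, 5):
        pred = torch.randn(2, 512, n)           # 2 x 512 = 1024 tokens, n logits each
        target = torch.randint(0, n, (2, 512))  # one class id per token
        print(n, loss_fn(pred, target).item())
```

The same function works for three labels and for five, because the second dimension passed to view always matches what the model actually produced.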
Hi @manikanthp
Thanks for sharing the repo.
I tried to run this code on my custom dataset, which has five classes. I am attaching the Label Studio output file and the converted file.
When I try to run the code, I get the following error:
Traceback (most recent call last):
  File "F:\PyCharmProjects\LayoutLMTrial\main.py", line 35, in <module>
    train_loss = train_fn(dataload, model, optimizer)
  File "F:\PyCharmProjects\LayoutLMTrial\engine.py", line 11, in train_fn
    _, loss = model(data)
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "F:\PyCharmProjects\LayoutLMTrial\trainer.py", line 34, in forward
    loss = loss_fn(output,lables)
  File "F:\PyCharmProjects\LayoutLMTrial\trainer.py", line 13, in loss_fn
    return nn.CrossEntropyLoss()(pred.view(-1,4),target.view(-1))
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\modules\loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "F:\PyCharmProjects\LayoutLMTrial\venv\lib\site-packages\torch\nn\functional.py", line 3053, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (1280) to match target batch_size (1024).
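The two numbers in this error can be checked with a little arithmetic. This is a sketch under the assumption that the model built with ModelModule(5) now emits 5 logits per token while loss_fn in trainer.py still calls pred.view(-1,4); the 1024 is the target size reported in the error, and the helper rows_after_view is hypothetical.

```python
def rows_after_view(num_tokens: int, num_classes: int, view_cols: int) -> int:
    # Hypothetical helper: how many rows pred.view(-1, view_cols) produces
    # when the model emits num_classes logits for each of num_tokens tokens.
    total = num_tokens * num_classes
    if total % view_cols != 0:
        raise ValueError(f"{total} elements cannot be viewed as (-1, {view_cols})")
    return total // view_cols


num_tokens = 1024  # target.view(-1) length from the error message
print(rows_after_view(num_tokens, num_classes=5, view_cols=4))  # 1280, does not match 1024
print(rows_after_view(num_tokens, num_classes=5, view_cols=5))  # 1024, matches the target
```

Under that assumption, 1024 tokens x 5 logits = 5120 values, and viewing them in rows of 4 gives 1280 rows, exactly the input batch_size in the error. With 5 columns the row count matches the 1024 targets again.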
The only change I made to the code was changing the number of classes from 4 to 5 in main.py, as shown below:
model = ModelModule(5)
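One way to keep the class count in main.py and the number in loss_fn from drifting apart is to store it once on the module and reuse it in the loss. The sketch below is hypothetical: ToyModelModule, its hidden_size default of 768, and the features/labels forward signature are assumptions, and the single linear head stands in for the LayoutLMv3 model the repo actually uses.

```python
import torch
import torch.nn as nn


class ToyModelModule(nn.Module):
    # Hypothetical stand-in, not the repo's ModelModule: a linear token
    # classifier over precomputed features, used only to show the plumbing.
    def __init__(self, num_classes: int, hidden_size: int = 768):
        super().__init__()
        self.num_classes = num_classes  # single source of truth for the class count
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, features: torch.Tensor, labels: torch.Tensor):
        logits = self.classifier(features)  # (batch, seq_len, num_classes)
        # The loss reuses self.num_classes, so ToyModelModule(5) needs no other edit.
        loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_classes), labels.view(-1))
        return logits, loss


model = ToyModelModule(5)
features = torch.randn(2, 8, 768)
labels = torch.randint(0, 5, (2, 8))
logits, loss = model(features, labels)
print(logits.shape)  # torch.Size([2, 8, 5])
```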
Can you please help me fix this issue?
Thanks,
Rishabh Gupta
Attachments: Training_json_1.json, Training_layoutLMV3_1.json