analogdevicesinc / ai8x-training

Model Training for ADI's MAX78000 and MAX78002 Edge AI Devices
Apache License 2.0

Faceid training and evaluation: tensor size mismatch when batch size > 1 #184

Closed. lx2u16 closed this issue 1 year ago.

lx2u16 commented 2 years ago

Hi, I am currently trying to run the training and evaluation scripts for FaceID. They only work when I set the batch size to 1; any other batch size produces the tensor size mismatch error below.

For example, with batch size 100:

```
Traceback (most recent call last):
  File "train.py", line 1794, in <module>
    main()
  File "train.py", line 438, in main
    return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors,
  File "train.py", line 1495, in evaluate_model
    top1, _, _, _, mAP = test(test_loader, model, criterion, loggers, activations_collectors,
  File "train.py", line 958, in test
    top1, top5, vloss, APs, mAP = _validate(test_loader, model, criterion, loggers, args)
  File "train.py", line 1163, in _validate
    classerr.add(output.data.permute(0, 2, 3, 1).flatten(start_dim=0,
  File "/opt/conda/lib/python3.8/site-packages/torchnet/meter/msemeter.py", line 21, in add
    self.sesum += torch.sum((output - target) ** 2)
RuntimeError: The size of tensor a (512) must match the size of tensor b (51200) at non-singleton dimension 1
```

Could you please help me figure out how to fix this? Thanks.
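
In case it helps, the mismatch can be reproduced outside of train.py. Below is a minimal sketch; the (N, 512, 1, 1) output and target shapes are my guess based on the 512 vs 51200 numbers in the error message, not something I have confirmed against the FaceID dataloader:

```python
import torch

# Assumed shapes: N = batch size, C = 512-dim FaceID embedding,
# H = W = 1 (a guess consistent with the (512) vs (51200) error).
N, C = 100, 512
output = torch.randn(N, C, 1, 1)
target = torch.randn(N, C, 1, 1)

# What the failing line in train.py does before calling MSEMeter.add():
out = output.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2)  # (N, C) = (100, 512)
tgt = target.flatten()                                            # (N*C,) = (51200,)

# MSEMeter.add() computes torch.sum((out - tgt) ** 2); broadcasting
# (100, 512) against (51200,) fails exactly as in the traceback above.
# With N = 1, (1, 512) - (512,) broadcasts fine, which is why batch
# size 1 works.
diff = out - tgt  # RuntimeError: The size of tensor a (512) must match ...
```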

justp7 commented 1 year ago

I ran into the same problem when running the scripts. May I ask if there has been any progress on a fix?

justp7 commented 1 year ago

By modifying lines 779 and 1111 of train.py to change `classerr.add(output.data.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2), target.flatten())` to `classerr.add(output.data, target)`, I got it working properly. I don't know whether this is the correct fix, but it works! With best wishes.
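
To spell out the change (a sketch only; it assumes output and target arrive with identical shapes, which appears to be the case for these regression targets):

```python
# Before (train.py lines 779 and 1111): target.flatten() collapses the
# batch dimension too, so for N > 1 the operands no longer line up.
classerr.add(output.data.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2),
             target.flatten())

# After: pass both tensors through unchanged. If they share a shape,
# MSEMeter subtracts them elementwise and no broadcasting error occurs.
classerr.add(output.data, target)
```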

github-actions[bot] commented 1 year ago

This issue has been marked stale because it has been open for over 30 days with no activity. It will be closed automatically in 10 days unless a comment is added or the "Stale" label is removed.

ermanok commented 1 year ago

Commit #187 fixes this issue. Thanks for reporting.