limhoyeon / ToothGroupNetwork

3D Dental surface segmentation with Tooth Group Network

Changing batch_size from 1 to anything else in TSegNet causes an error. #8

Closed M-Alsouqi closed 1 year ago

M-Alsouqi commented 1 year ago

Hello. You've done a great job with this repository. Unfortunately, I wasn't able to train TGNet due to VRAM constraints. However, I tried training TSegNet since it doesn't need much VRAM.

Training TSegNet with batch_size=1 works fine. However, when I tried to change batch_size to a larger number, an error occurred:

I assume that the loss for TSegNet is not coded to support a batch_size other than 1? Does it only accept a single data point at a time?


```
  File "C:\Users\User\Desktop\ToothGroupNetwork\start_train.py", line 55, in <module>
    runner(config, model)
  File "C:\Users\User\Desktop\ToothGroupNetwork\runner.py", line 57, in runner
    trainner.run()
  File "C:\Users\User\Desktop\ToothGroupNetwork\trainer.py", line 118, in run
    self.train(epoch, train_data_loader)
  File "C:\Users\User\Desktop\ToothGroupNetwork\trainer.py", line 38, in train
    loss = self.model.step(batch_idx, batch_item, "train")
  File "C:\Users\Desktop\ToothGroupNetwork\models\tsegnet_model.py", line 49, in step
    gt_centroid_coords, gt_centroid_exists = ou.seg_label_to_cent(batch_item["feat"][:,:3,:], batch_item["gt_seg_label"])
  File "C:\Users\User\Desktop\ToothGroupNetwork\ops_utils.py", line 173, in seg_label_to_cent
    gt_coords = gt_coords.view(-1,3)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```

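
For reference, the `RuntimeError` itself can be reproduced outside the repository. The following is a minimal sketch, not the repository's code: it assumes the coordinates arrive as a `(B, 3, N)` tensor and are permuted to `(B, N, 3)` before the `.view(-1,3)` call shown in the traceback. The permute step and the helper names `flatten_coords_view`, `flatten_coords_reshape`, and `view_succeeds` are hypothetical. Under that assumption, `.view` succeeds when `B` is 1 and fails for larger `B`, while `.reshape` (as the error message suggests) handles both.

```python
import torch


def flatten_coords_view(coords: torch.Tensor) -> torch.Tensor:
    """Flatten (B, 3, N) coordinates to (B*N, 3) using .view.

    The permute is an assumed preprocessing step (not confirmed by the
    traceback); it makes the tensor non-contiguous in memory.
    """
    return coords.permute(0, 2, 1).view(-1, 3)


def flatten_coords_reshape(coords: torch.Tensor) -> torch.Tensor:
    """Same flattening, but .reshape copies the data when a view is impossible."""
    return coords.permute(0, 2, 1).reshape(-1, 3)


def view_succeeds(batch_size: int, num_points: int = 8) -> bool:
    """Return True if the .view-based flattening works for this batch size."""
    coords = torch.arange(
        batch_size * 3 * num_points, dtype=torch.float32
    ).view(batch_size, 3, num_points)
    try:
        flatten_coords_view(coords)
        return True
    except RuntimeError:
        # Raised when the batch and point dimensions cannot be merged
        # without copying, which is the error quoted in the traceback.
        return False


if __name__ == "__main__":
    for b in (1, 2, 4):
        print(f"batch_size={b}: view works = {view_succeeds(b)}")

    coords = torch.arange(2 * 3 * 8, dtype=torch.float32).view(2, 3, 8)
    print(flatten_coords_reshape(coords).shape)
```

Swapping `.view` for `.reshape` (or calling `.contiguous()` first) only removes this particular exception. It does not by itself show that the rest of the TSegNet loss handles batches larger than 1.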
M-Alsouqi commented 1 year ago

Oh, I just noticed the note you wrote in the Training section. I will look into it, then.