Zyun-Y / DconnNet

Codes for CVPR2023 paper "Directional Connectivity-based Segmentation of Medical Images"

Training error #19

Closed effort1121 closed 11 months ago

effort1121 commented 11 months ago

Running python train.py fails with:

Traceback (most recent call last):
  File "train.py", line 150, in <module>
    main(args)
  File "train.py", line 146, in main
    solver.train(model, train_loader, val_loader,exp_id+1, num_epochs=args.epochs)
  File "/root/autodl-tmp/DconnNet-main/solver.py", line 147, in train
    loss_main = self.loss_func(output, y)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/autodl-tmp/DconnNet-main/connect_loss.py", line 172, in forward
    loss = self.multi_class_forward(c_map, target)
  File "/root/autodl-tmp/DconnNet-main/connect_loss.py", line 184, in multi_class_forward
    onehotmask = onehotmask.permute(0,3,1,2)
RuntimeError: number of dims don't match in permute

Zyun-Y commented 11 months ago

Please check the shape of your onehotmask just before the permute call. After F.one_hot it should be (B, H, W, C), where C is the number of classes, and permute(0, 3, 1, 2) should then return (B, C, H, W). Also check that you have set the correct number of classes.
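The shape check described above can be sketched as follows. This is only an illustration under assumptions, not code from the DconnNet repository: `to_onehot_bchw` is a hypothetical helper, and it assumes the extra dimension comes from a label tensor shaped (B, 1, H, W) instead of (B, H, W), which is one common way to hit this error. Since F.one_hot appends the class dimension at the end, a 4-D label tensor becomes 5-D and a four-index permute no longer matches.

```python
import torch
import torch.nn.functional as F


def to_onehot_bchw(target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert integer labels to a (B, C, H, W) one-hot mask.

    Hypothetical helper for illustration; not part of the DconnNet code.
    """
    # Assumption: a singleton channel dimension (B, 1, H, W) is what adds
    # the extra dimension, so drop it to get (B, H, W) labels.
    if target.dim() == 4 and target.size(1) == 1:
        target = target.squeeze(1)
    if target.dim() != 3:
        raise ValueError(f"expected (B, H, W) labels, got {tuple(target.shape)}")
    # F.one_hot needs int64 labels and appends the class dimension last.
    onehot = F.one_hot(target.long(), num_classes)  # (B, H, W, C)
    return onehot.permute(0, 3, 1, 2).float()       # (B, C, H, W)


labels_3d = torch.randint(0, 4, (2, 8, 8))  # (B, H, W), class ids in [0, 4)
labels_4d = labels_3d.unsqueeze(1)          # (B, 1, H, W)

# One-hot of a 4-D label tensor is 5-D, so permute(0, 3, 1, 2) would fail.
print(F.one_hot(labels_4d, 4).dim())

# Both label layouts end up as (B, C, H, W) after the helper.
print(to_onehot_bchw(labels_3d, 4).shape)
print(to_onehot_bchw(labels_4d, 4).shape)
```

Printing `onehotmask.shape` right before line 184 of connect_loss.py would show which case applies. If the shape is already (B, H, W, C) but C looks wrong, the class number setting is the more likely culprit.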