effort1121 closed this issue 11 months ago.
Please check the dimensions of your onehotmask before the permute. After F.one_hot it should be (B, H, W, C), and after permute(0, 3, 1, 2) it should be (B, C, H, W). Also check that you set the correct number of classes.
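The shape flow described in the reply can be checked in isolation with a minimal sketch like the one below. The batch size, image size, and class count are hypothetical values chosen only for illustration, and the label map is assumed to be an integer tensor of shape (B, H, W) holding one class index per pixel.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
B, H, W, NUM_CLASSES = 2, 4, 4, 3

# Assumed input: integer label map of shape (B, H, W), one class index per pixel.
target = torch.randint(0, NUM_CLASSES, (B, H, W))

# F.one_hot appends the class axis at the end: (B, H, W) -> (B, H, W, C).
onehotmask = F.one_hot(target.long(), num_classes=NUM_CLASSES)
print(onehotmask.shape)  # torch.Size([2, 4, 4, 3])

# Move the class axis next to the batch axis: (B, H, W, C) -> (B, C, H, W).
onehotmask = onehotmask.permute(0, 3, 1, 2)
print(onehotmask.shape)  # torch.Size([2, 3, 4, 4])

# The class count must be larger than the largest label value,
# otherwise F.one_hot raises a RuntimeError.
assert int(target.max()) < NUM_CLASSES
```

If the first print does not show four dimensions with the class count last, the problem lies in the target passed to the loss rather than in the permute call itself.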
Running python train.py fails with:
Traceback (most recent call last):
  File "train.py", line 150, in
    main(args)
  File "train.py", line 146, in main
    solver.train(model, train_loader, val_loader,exp_id+1, num_epochs=args.epochs)
  File "/root/autodl-tmp/DconnNet-main/solver.py", line 147, in train
    loss_main = self.loss_func(output, y)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/autodl-tmp/DconnNet-main/connect_loss.py", line 172, in forward
    loss = self.multi_class_forward(c_map, target)
  File "/root/autodl-tmp/DconnNet-main/connect_loss.py", line 184, in multi_class_forward
    onehotmask = onehotmask.permute(0,3,1,2)
RuntimeError: number of dims don't match in permute
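One common way to hit this error, offered here only as an assumption since the issue does not show the shape of y, is a label map that still carries a singleton channel axis, (B, 1, H, W). F.one_hot then returns a 5-D tensor, and permute(0, 3, 1, 2) lists only four axes. The sketch below reproduces that case and shows a defensive conversion; to_onehot_bchw is a hypothetical helper, not part of DconnNet.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 3  # hypothetical class count

# Assumed cause: labels shaped (B, 1, H, W) instead of (B, H, W).
target = torch.randint(0, NUM_CLASSES, (2, 1, 4, 4))

onehot = F.one_hot(target, num_classes=NUM_CLASSES)  # 5-D: (B, 1, H, W, C)
try:
    onehot.permute(0, 3, 1, 2)  # four axes given for a five-axis tensor
except RuntimeError as err:
    print("permute failed:", err)


def to_onehot_bchw(labels, num_classes):
    """Hypothetical helper: turn integer labels into a (B, C, H, W) one-hot mask."""
    if labels.dim() == 4 and labels.size(1) == 1:
        labels = labels.squeeze(1)  # drop the singleton channel axis
    if labels.dim() != 3:
        raise ValueError(f"expected (B, H, W) labels, got {tuple(labels.shape)}")
    return F.one_hot(labels.long(), num_classes).permute(0, 3, 1, 2)


print(to_onehot_bchw(target, NUM_CLASSES).shape)  # torch.Size([2, 3, 4, 4])
```

Printing y.shape just before self.loss_func(output, y) in solver.py would confirm whether this assumption applies before changing anything in connect_loss.py.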