MontaEllis / Pytorch-Medical-Segmentation

This repository is an unofficial PyTorch implementation of medical segmentation in 2D and 3D.
MIT License
864 stars, 196 forks

Dimension mismatch in the loss function with a 3D dataset #50

Closed: smallkaka closed this issue 1 year ago

smallkaka commented 1 year ago

Hello, and thank you very much for sharing your code on GitHub. Here is the problem I ran into when training on 3D CT data:

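During training, the error occurs around line 260 of main.py, i.e. loss = criterion_ce(outputs, y.argmax(dim=1)) + criterion_dice(outputs, y.argmax(dim=1)). Stepping into it, I found that predict and target have different shapes. Taking a 64×64×96 volume with batch size 4 and 2 channels as an example, predict is (4, 2, 64, 64, 96) while target is (4, 128, 64, 96). The channel dimension is missing, so the Dice loss cannot run, and self._one_hot_encoder doubles the second dimension, i.e. it stacks the classes along it. After I added the line target = torch.unsqueeze(target, 1), training runs, but TensorBoard shows the loss decreasing slowly while the Dice score drops straight to 0.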
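The shape arithmetic in this report can be reproduced in a few lines of PyTorch. This is a minimal sketch that assumes y is a one-hot label tensor of shape (N, C, D, H, W) and that the encoder behaves like the loop quoted later in this thread; one_hot_encoder is a hypothetical stand-in, not the repository's function.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the repository's _one_hot_encoder, modeled on the
# loop quoted later in this thread; not copied from loss_function.py.
def one_hot_encoder(input_tensor: torch.Tensor, n_classes: int) -> torch.Tensor:
    tensor_list = []
    for i in range(n_classes):
        # Mask of voxels whose label equals class i (same shape as the input).
        tensor_list.append(input_tensor == i * torch.ones_like(input_tensor))
    # Concatenating on dim 1 only yields (N, C, D, H, W) if dim 1 is a size-1 channel axis.
    return torch.cat(tensor_list, dim=1).float()

n_classes = 2
outputs = torch.randn(4, n_classes, 64, 64, 96)               # logits, (N, C, D, H, W)
index_labels = torch.randint(0, n_classes, (4, 64, 64, 96))
# Assumed one-hot label y with the class axis at dim 1.
y = F.one_hot(index_labels, n_classes).permute(0, 4, 1, 2, 3).float()

target = y.argmax(dim=1)                          # (4, 64, 64, 96): channel axis removed
bad = one_hot_encoder(target, n_classes)          # depth slices of both classes get stacked
print(bad.shape)   # torch.Size([4, 128, 64, 96])

good = one_hot_encoder(torch.unsqueeze(target, 1), n_classes)  # the fix from the post
print(good.shape)  # torch.Size([4, 2, 64, 64, 96])
```

With the channel axis restored, good has the same shape as outputs, which is what the assertion in loss_function.py checks.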

D:\ProgramData\miniconda3\envs\pytorch\python.exe E:/python_learning/Pytorch-Medical-Segmentation-master/main.py
epoch:1 Batch: 0/5 epoch 1

Traceback (most recent call last):
  File "E:\python_learning\Pytorch-Medical-Segmentation-master\main.py", line 582, in <module>
    train()
  File "E:\python_learning\Pytorch-Medical-Segmentation-master\main.py", line 265, in train
    loss = criterion_ce(outputs, y.argmax(dim=1)) + criterion_dice(outputs, y.argmax(dim=1))
  File "D:\ProgramData\miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\python_learning\Pytorch-Medical-Segmentation-master\loss_function.py", line 141, in forward
    assert inputs.size() == target.size(), 'predict & target shape do not match'
AssertionError: predict & target shape do not match

Process finished with exit code 1

Many thanks for your reply.

MontaEllis commented 1 year ago

Could it be that you passed in the wrong target data?

smallkaka commented 1 year ago

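[Screenshot 2023-02-20_090359] I don't quite understand what you mean. The task is binary segmentation, and the labels contain only 0 and 1. I found this while debugging the code, but I can't tell where the problem comes from. Here are the shapes of the two tensors inside DiceLoss after one-hot encoding: [Screenshot 2023-02-20_170826]. After that, the shape-mismatch error is raised.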

yuijn456 commented 1 year ago

same question

yuijn456 commented 1 year ago

After applying target.reshape([2,2,32,32,32]), it ran successfully.
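For these particular shapes the reshape workaround is equivalent to restoring the channel axis before encoding: concatenating the per-class masks on dim 1 places each class's 32 depth slices contiguously, so reshaping (2, 64, 32, 32) to (2, 2, 32, 32, 32) recovers the (N, C, D, H, W) layout. A minimal check, assuming a batch of 2 binary 32×32×32 label volumes and a hypothetical one_hot_encoder that mirrors the loop quoted in this thread:

```python
import torch

def one_hot_encoder(input_tensor: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Hypothetical mirror of the quoted loop: one mask per class, concatenated on dim 1.
    return torch.cat([input_tensor == i for i in range(n_classes)], dim=1).float()

labels = torch.randint(0, 2, (2, 32, 32, 32))      # argmax'ed target, channel axis missing
stacked = one_hot_encoder(labels, 2)               # (2, 64, 32, 32)
reshaped = stacked.reshape([2, 2, 32, 32, 32])     # the workaround from this comment
proper = one_hot_encoder(labels.unsqueeze(1), 2)   # (2, 2, 32, 32, 32)
print(torch.equal(reshaped, proper))  # True
```

Since both targets are identical, the reshape by itself should not change the Dice result. Note that it hardcodes the batch and patch size, whereas unsqueeze(1) works for any shape.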

smallkaka commented 1 year ago

After it runs, Dice will probably stay at 0. I ran into the problem from your issue as well.

y2kcoding commented 1 year ago

Quoting smallkaka: "After it runs, Dice will probably stay at 0. I ran into the problem from your issue as well."

Hi, I'm also getting Dice = 0. How did you solve it?

JueWuANDMiWang commented 1 year ago

Same problem here: dimension mismatch. In main.py I changed loss = criterion_ce(outputs, y.argmax(dim=1)) + criterion_dice(outputs, y.argmax(dim=1)) to loss = criterion_ce(outputs, y) + criterion_dice(outputs, y), and in loss_function.py I replaced the code at line 108,

    tensor_list = []
    for i in range(self.n_classes):
        temp_prob = input_tensor == i * torch.ones_like(input_tensor)
        tensor_list.append(temp_prob)
    output_tensor = torch.cat(tensor_list, dim=1)

with output_tensor = torch.ones_like(input_tensor). This makes the error go away, but Dice equals 0.
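Replacing the one-hot loop with torch.ones_like removes the shape error, but it also removes the labels: the encoded target becomes the same all-ones tensor for every input, so the Dice term no longer measures agreement with the ground truth. This plausibly contributes to the Dice = 0 reading, although the thread does not show how the logged Dice is computed. The sketch below uses small hypothetical tensors and helper names (original_encoder, replaced_encoder), assumed for illustration, to compare the two encoders on two different masks.

```python
import torch

n_classes = 2

def original_encoder(labels: torch.Tensor) -> torch.Tensor:
    # One binary mask per class, concatenated on the channel axis (the quoted loop).
    return torch.cat([labels == i for i in range(n_classes)], dim=1).float()

def replaced_encoder(labels: torch.Tensor) -> torch.Tensor:
    # The workaround from this comment: the label values are never read.
    return torch.ones_like(labels).float()

labels_a = torch.zeros(2, 1, 8, 8, 8, dtype=torch.long)
labels_a[..., 4:] = 1                        # half of each volume is foreground
labels_b = torch.zeros_like(labels_a)        # an all-background mask

print(torch.equal(original_encoder(labels_a), original_encoder(labels_b)))  # False
print(torch.equal(replaced_encoder(labels_a), replaced_encoder(labels_b)))  # True
```

The original encoder distinguishes the two masks; the replacement produces the same target for both.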

wzj9846 commented 1 year ago

Quoting yuijn456: "After applying target.reshape([2,2,32,32,32]), it ran successfully." Nice one!

JueWuANDMiWang commented 1 year ago

Has anyone solved the Dice = 0 problem?

BelieferQAQ commented 1 year ago

(Quoting smallkaka's original post above in full.)

Hello, are you using the specific versions of Python, PyTorch, and TorchIO that the author requires? I also encountered this issue, but I'd like to know why your channel count is 2. My dataset is CT, so the channel count should be 1, right?
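In the original post, "channel = 2" most likely refers to the network's output channels, one per class for binary segmentation, rather than the CT input, which usually has a single channel. A related check for the Dice = 0 reports: if the label tensor y is stored as a single-channel mask of shape (N, 1, D, H, W), which is an assumption here rather than something the thread confirms, then y.argmax(dim=1) returns 0 for every voxel and the target collapses to background. A minimal sketch with small hypothetical shapes:

```python
import torch

# Assumed layout: a binary CT mask stored with one channel, values 0/1.
mask = torch.zeros(4, 1, 16, 16, 24, dtype=torch.float32)
mask[:, :, 4:12, 4:12, 8:16] = 1.0               # a foreground cube in every sample

collapsed = mask.argmax(dim=1)                   # argmax over a size-1 axis is always index 0
print(collapsed.shape, collapsed.sum().item())   # torch.Size([4, 16, 16, 24]) 0

labels = mask.squeeze(1).long()                  # class indices 0/1, shape (N, D, H, W)
print(labels.sum().item())                       # 2048 foreground voxels (4 * 8 * 8 * 8)

outputs = torch.randn(4, 2, 16, 16, 24)          # logits for 2 classes: background, foreground
# CrossEntropyLoss accepts (N, C, D, H, W) logits with (N, D, H, W) integer targets.
loss = torch.nn.CrossEntropyLoss()(outputs, labels)
```

If y.shape in main.py shows a size-1 channel axis, squeezing it into class indices (as with labels above) instead of taking argmax would keep the foreground; this is a hypothesis to verify, not a confirmed fix.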

knowleton commented 1 year ago

Quoting JueWuANDMiWang: "Has anyone solved the Dice = 0 problem?"

Has anyone solved Dice = 0? Did you manage to fix it?