Kaiseem / DAR-UNet

[JBHI2022] A novel 3D unsupervised domain adaptation framework for cross-modality medical image segmentation
Apache License 2.0

Two problems encountered when running the code #7

Open W0215-git opened 1 year ago

W0215-git commented 1 year ago

Hello! Your approach is distinctive and the results are very good. I tried to run experiments with your released code. After reading the other issues, I used the stage-2 dataset you shared to run the stage-2 experiment, but I ran into the following two problems:

1. The __getitem__ function in your dataloader3d file returns output = {'img': A_img, 'label': A_label}, but model training uses A_img=train_A_data['A_img'].cuda() and A_label=train_A_data['A_label'].cuda(), which causes this error:

Traceback (most recent call last):
  File "/home/lab312/PycharmProjects/DAR-UNet-main/stage_2_seg_train.py", line 132, in <module>
    A_img=train_A_data['A_img'].cuda()
KeyError: 'A_img'

Should these be changed to A_img=train_A_data['img'].cuda() and A_label=train_A_data['label'].cuda()?

2. I tried changing the code to A_img=train_A_data['img'].cuda() and A_label=train_A_data['label'].cuda(), but then hit the following error:

Traceback (most recent call last):
  File "/home/lab312/PycharmProjects/MyWork/stage_2_seg_train.py", line 138, in <module>
    for iteration, train_A_data in enumerate(train_A_loader):
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
NameError: Caught NameError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/anaconda3/envs/uda/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/lab312/PycharmProjects/MyWork/utils/dataloader3d.py", line 78, in __getitem__
    seg_augmented = self.aug(image=A_img, mask=A_label)
  File "/home/lab312/.local/lib/python3.9/site-packages/volumentations/core/composition.py", line 60, in __call__
    data = tr(force_apply, self.targets, data)
  File "/home/lab312/.local/lib/python3.9/site-packages/volumentations/core/transforms_interface.py", line 117, in __call__
    data[k] = self.apply(v, params)
  File "/home/lab312/.local/lib/python3.9/site-packages/volumentations/augmentations/transforms.py", line 131, in apply
    return F.rescale_warp(img, scale, interpolation=self.interpolation)
  File "/home/lab312/.local/lib/python3.9/site-packages/volumentations/augmentations/functional.py", line 441, in rescale_warp
    return map_coordinates(img, coords, order=interpolation, mode=border_mode, cval=value)
NameError: name 'border_mode' is not defined

The error is raised at the line seg_augmented = self.aug(image=A_img, mask=A_label). How can I fix this? The code only runs if I comment out self.aug (i.e., no augmentation), but then it does not reach the expected results. I would appreciate your help.
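For problem 1, the underlying rule is that the training loop has to read exactly the keys that the dataset's `__getitem__` in utils/dataloader3d.py returns (`'img'` and `'label'`, per the comment above). A minimal sketch, assuming that dict layout: `unpack_batch` is a hypothetical helper, not part of DAR-UNet, that fails with a readable message listing the keys actually present instead of a bare `KeyError`.

```python
def unpack_batch(batch, img_key="img", label_key="label"):
    """Hypothetical helper: pull image and label out of a collated batch dict.

    The keys must match the dict returned by the dataset's __getitem__
    (assumed here to be {'img': ..., 'label': ...}, as in dataloader3d.py).
    """
    missing = [k for k in (img_key, label_key) if k not in batch]
    if missing:
        # Report what the DataLoader actually delivered, not just the bad key.
        raise KeyError(f"batch is missing {missing}; available keys: {sorted(batch)}")
    return batch[img_key], batch[label_key]


# Usage inside the stage-2 training loop (device transfer as in the original):
# A_img, A_label = unpack_batch(train_A_data)
# A_img, A_label = A_img.cuda(), A_label.cuda()
```

PyTorch's default collate function keeps the dict keys of each sample, so whatever names `__getitem__` uses are the names the loop must use.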

Kaiseem commented 1 year ago

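Hi. For problem 1: yes, that is a bug. For problem 2: as I recall, it is a bug in the volumentations-3d library itself. At the time I went into the library source where the error is raised and specified border_mode explicitly: in the rescale_warp function in volumentations/augmentations/functional.py, set border_mode='constant', value=0.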
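The patch amounts to giving `map_coordinates` an explicit border mode and fill value. The call in the traceback matches the signature of `scipy.ndimage.map_coordinates`, which this sketch assumes. It is not the volumentations source: `rescale_warp_fixed` is a hypothetical 3D rescale (zoom about the volume center) that shows the same call with `mode='constant', cval=0` bound as ordinary keyword defaults, so no undefined `border_mode` name can reach `map_coordinates`.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def rescale_warp_fixed(img, scale, interpolation=1, border_mode="constant", value=0):
    """Hypothetical stand-in for volumentations' rescale_warp.

    Zooms a 3D volume about its center by `scale` (>1 zooms in, <1 zooms out).
    border_mode/value default to the author's patch: 'constant' and 0.
    """
    if img.ndim != 3:
        raise ValueError("expected a 3D volume")
    # Volume center, shaped to broadcast over a (3, D, H, W) coordinate grid.
    center = ((np.asarray(img.shape, dtype=np.float64) - 1.0) / 2.0)[:, None, None, None]
    # Coordinates of every output voxel, shape (3, D, H, W).
    grid = np.indices(img.shape, dtype=np.float64)
    # Inverse mapping: where in the input each output voxel samples from.
    coords = (grid - center) / scale + center
    # Samples falling outside the input are filled with `value`.
    return map_coordinates(img, coords, order=interpolation, mode=border_mode, cval=value)
```

For label masks, pass `interpolation=0` (nearest neighbour) so interpolation cannot produce fractional label values.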

W0215-git commented 1 year ago

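Thanks for the reply! I ran into a problem during testing. I loaded the ct2mr.pt model you provided and found that it contains both a seg and a seg_ema model, so I chose the seg model. But loading it failed because of mismatched keys: RuntimeError: Error(s) in loading state_dict for DARUnet: Missing key(s) in state_dict: "L1_fromimg.conv_skip.1.weight", "L1_fromimg.conv_skip.1.bias", Unexpected key(s) in state_dict: "attn4.conv_encoder.0.weight", "attn4.conv_encoder.0.bias", What is causing this? If I set strict to False when loading the model, the predictions are very poor. Could the path of the dataset I am predicting on be the problem? Below is my test code, adapted from your training code: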

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

from torch.utils.data import DataLoader
from adabelief_pytorch import AdaBelief
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import sys
from utils import SEGDataset
from models import DARUnet
import time
from utils.tta import PatchInferencer

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, weight=None, ignore_index=255, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        self.weight = weight

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        i = input
        t = target

        # Change the shape of input and target to B x N x num_voxels.
        i = i.view(i.size(0), i.size(1), -1)
        t = t.view(t.size(0), t.size(1), -1)

        # Compute the log-probabilities over the class dimension.
        logpt = F.log_softmax(i, dim=1)
        # Get the probabilities.
        pt = torch.exp(logpt)  # B,H*W or B,N,H*W

        if self.weight is not None:
            class_weight = torch.as_tensor(self.weight)
            class_weight = class_weight.to(i)

            at = class_weight[None, :, None]
            at = at.expand((t.size(0), -1, t.size(2)))
            logpt = logpt * at

        # Compute the loss over the mini-batch.
        weight = torch.pow(-pt + 1.0, self.gamma)
        loss = torch.mean(-weight * t * logpt, dim=-1)
        return loss.mean()

class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes

def dice_loss_chill(output, gt):
    num = (output*gt).sum(dim=[2, 3, 4])
    denom = output.sum(dim=[2, 3, 4]) + gt.sum(dim=[2, 3, 4]) + 0.001
    return num, denom

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='experiment')
parser.add_argument('--ckpt_path', type=str, default='datasets/checkpoints/ct2mr.pt')
parser.add_argument('--train_dataroot', type=str, default='datasets/source2target_training_npy_multi_style')
parser.add_argument('--val_dataroot', type=str, default='datasets/source_test_npy')
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epoch_max', type=int, default=100)

opts = parser.parse_args()

if __name__ == '__main__':
    epoch_max = opts.epoch_max
    test_A_loader = DataLoader(dataset=SEGDataset(opts.val_dataroot, opts.num_classes), batch_size=1, shuffle=True, drop_last=True, num_workers=0, pin_memory=True)

    USE_CUDA = torch.cuda.is_available()
    device_ids = [0, 1]
    device = torch.device("cuda" if USE_CUDA else "cpu")
    netS = DARUnet()
    # netS = torch.nn.DataParallel(netS, device_ids=device_ids)

    # The checkpoint stores both 'seg' and 'seg_ema' weights; load 'seg'.
    state_dict = torch.load(opts.ckpt_path)
    netS.load_state_dict(state_dict['seg'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    netS.to(device)

    ipp = PatchInferencer(n_class=opts.num_classes, TTA=False)

    netS.eval()
    dices = []
    for val_data in test_A_loader:
        for k in val_data.keys():
            val_data[k] = val_data[k].cuda().detach()
        with torch.no_grad():
            # output = ipp(netS, val_data['A_img'])
            output = ipp(netS, val_data['img'])
            pred = F.one_hot(torch.argmax(output, 1), opts.num_classes).permute(0, 4, 1, 2, 3)
            # gt=val_data['A_label']
            gt = val_data['label']
            num, denom = dice_loss_chill(pred, gt)
            # Mean Dice over the foreground classes (index 0 is background).
            d = (2 * num / denom)[:, 1:].mean().cpu().numpy()
            dices.append(d)
    torch.cuda.empty_cache()
    dices = np.mean(dices)

    print(dices)
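The evaluation loop above computes, per class, Dice = 2|P∩G| / (|P| + |G|) on one-hot tensors and then averages over the foreground classes (the `[:, 1:]` slice). A numpy restatement of that computation, as a sketch assuming prediction and ground truth are integer label volumes with background = 0; `per_class_dice` is a hypothetical name and keeps the same 0.001 term that `dice_loss_chill` adds to the denominator.

```python
import numpy as np


def per_class_dice(pred, gt, num_classes, eps=1e-3):
    """Hypothetical numpy version of the script's Dice evaluation.

    pred, gt: integer label volumes of identical shape (0 = background).
    Returns Dice for classes 1..num_classes-1, mirroring
    (2 * num / denom)[:, 1:] with denom smoothed by eps as in dice_loss_chill.
    """
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    dices = []
    for c in range(1, num_classes):  # skip background, like [:, 1:]
        p = pred == c
        g = gt == c
        num = np.logical_and(p, g).sum()
        denom = p.sum() + g.sum() + eps
        dices.append(2.0 * num / denom)
    return np.array(dices)


# Example with 5 classes (background + 4 structures), as in --num_classes 5:
# dices = per_class_dice(pred_labels, gt_labels, num_classes=5)
# print(dices, dices.mean())
```

Because of the eps term, a class missing from both volumes scores 0 rather than NaN, and a class the network never predicts also scores 0, which is what the near-zero per-class Dice values later in this thread look like.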
Kaiseem commented 1 year ago

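Hi, I have re-converted the model weights. The original weights had a problem I did not notice: they were the weights saved by the original code. Download the checkpoints and datasets from the link below and extract them into the code directory, then run. In my test the results differ slightly (±0.002), but it should be fine.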

Link: https://pan.baidu.com/s/1wXDjC_zJ3B7G3buhAXgVJg?pwd=19i0 Extraction code: 19i0
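The RuntimeError reported earlier means the checkpoint and the current DARUnet definition disagree on parameter names (L1_fromimg.conv_skip.* vs attn4.conv_encoder.*), which fits the explanation that the old file held weights saved by the original code. Loading with strict=False does not fix this: missing parameters keep their fresh initialization and unexpected ones are ignored, so poor predictions are expected. A hypothetical pure-Python helper (not part of the repo) to list the mismatch before loading:

```python
def key_mismatch(model_keys, ckpt_keys):
    """Hypothetical helper: compare parameter names of a model and a checkpoint.

    model_keys: e.g. netS.state_dict().keys()
    ckpt_keys:  e.g. torch.load(opts.ckpt_path)['seg'].keys()
    Returns (missing, unexpected) as sorted lists, with the same meaning as in
    PyTorch's load_state_dict error: 'missing' names are expected by the model
    but absent from the checkpoint; 'unexpected' names exist only in the checkpoint.
    """
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected
```

In PyTorch, `netS.load_state_dict(state_dict['seg'], strict=False)` also returns an object whose `missing_keys` and `unexpected_keys` attributes hold these two lists; if either is non-empty, the checkpoint does not match the model.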

W0215-git commented 1 year ago

Hello! I tested mr2ct and everything works. But the ct2mr test is abnormal: strangely, ASD is inf and Dice is 0.04. Please help!

Kaiseem commented 1 year ago

That's strange; it tests fine on my side:

Dice per class: 0.93438125 0.9030991 0.8814075 0.92374444 Overall Dice: 0.910658061504364

ASD per class: 0.742716326513328 0.5162332531058162 0.6698428784888762 0.6620714135699539 Overall ASD: 0.6477159679194936

Just in case, I have uploaded the dataset I tested with. For reproducibility, all my tests were run on the original data.

Link: https://pan.baidu.com/s/14b-ARlbZ7hnG6aO-fil9AA?pwd=picg Extraction code: picg

W0215-git commented 1 year ago

Thanks for your patient replies. I used the data you provided, but the result is the same. Here is my console output:

[0.07339802 0. 0. 0. ] [ 8.03247482 inf 21.37460618 17.53910189]
image shape:(256, 256, 36) mask shape:(256, 256, 36) target_shape:[395, 395, 72]
[0 1 2 3 4]
torch.Size([1, 5, 72, 395, 395])
[0.15282708 0. 0. 0.09119039] [ 6.88473816 27.4157674 25.96430555 3.16656022]
image shape:(320, 320, 30) mask shape:(320, 320, 30) target_shape:[435, 435, 60]
[0 1 2 3 4]
torch.Size([1, 5, 60, 435, 435])
[0.12263599 0. 0.01364036 0. ] [ 7.49903341 inf 16.78725337 9.64006033]
image shape:(256, 256, 39) mask shape:(256, 256, 39) target_shape:[465, 465, 78]
[0 1 2 3 4]
torch.Size([1, 5, 78, 465, 465])
/home/ytt/miniconda3/envs/uda/lib/python3.8/site-packages/monai/metrics/surface_distance.py:163: UserWarning: the prediction of class 2 is all 0, this may result in nan/inf distance.
  warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
[0.20619844 0.06618465 0. 0. ] [ 8.72714089 11.25044471 inf 57.80082477]
Dice per class: 0.13876489 0.016546162 0.003410091 0.022797598 Overall Dice: 0.04537968337535858
ASD per class: 7.785846822169248 inf inf 22.03663680363034 Overall ASD: inf
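The inf values line up with the MONAI warning in the log: when a class's prediction is all 0, there is no predicted surface to measure from, the surface distance is undefined and comes out as inf, and a single inf makes the per-class and overall means inf. A simplified sketch of a symmetric average surface distance that makes this case explicit; `average_surface_distance` is hypothetical, built on `scipy.ndimage`, and is not MONAI's implementation.

```python
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt


def _surface(mask):
    """Boundary voxels: the mask minus its erosion."""
    return mask & ~binary_erosion(mask)


def average_surface_distance(pred, gt, spacing=(1.0, 1.0, 1.0)):
    """Hypothetical symmetric ASD between two binary 3D masks.

    Returns inf when either mask is empty, the situation the MONAI warning
    describes (prediction of a class is all 0). `spacing` is the voxel size,
    so distances are returned in physical units.
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    if not pred.any() or not gt.any():
        return float("inf")
    sp, sg = _surface(pred), _surface(gt)
    # Distance from every voxel to the nearest surface voxel of each mask.
    dist_to_gt = distance_transform_edt(~sg, sampling=spacing)
    dist_to_pred = distance_transform_edt(~sp, sampling=spacing)
    # Average over both directions: pred surface -> gt, gt surface -> pred.
    d = np.concatenate([dist_to_gt[sp], dist_to_pred[sg]])
    return float(d.mean())
```

So an inf ASD together with a near-zero Dice points to the model predicting nothing for that class, not to a bug in the metric.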

Kaiseem commented 1 year ago

I'd suggest ruling out environment factors. Could you try an environment on a different computer?

My current environment: win11, python=3.7.3, pytorch==1.13.0, monai==1.0.1, pydicom==2.3.1, numpy==1.21.6
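Since the discrepancy is attributed to the environment, one quick check is to compare installed package versions with the ones listed above. A small sketch using the standard-library `importlib.metadata` (Python 3.8+; on Python 3.7, as in the author's setup, the `importlib_metadata` backport would be needed instead). `compare_versions` is hypothetical; the pinned versions are copied from the comment above.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions reported by the author in this thread (win11, python 3.7.3).
AUTHOR_ENV = {
    "torch": "1.13.0",
    "monai": "1.0.1",
    "pydicom": "2.3.1",
    "numpy": "1.21.6",
}


def compare_versions(expected):
    """Return {package: (expected, installed)} for every package whose
    installed version differs from `expected` (installed is None if absent)."""
    diffs = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        if have != want:
            diffs[pkg] = (want, have)
    return diffs


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for pkg, (want, have) in compare_versions(AUTHOR_ENV).items():
        print(f"{pkg}: author {want}, here {have}")
```

CUDA builds of torch report a local suffix such as `+cu117`, which shows up as a difference even when the base version matches.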

My console output:

image shape:(256, 256, 36) mask shape:(256, 256, 36) target_shape:[395, 395, 72]
[0 1 2 3 4]
torch.Size([1, 5, 72, 395, 395])
[0.94225276 0.91949344 0.89912534 0.9345378 ] [0.57352691 0.5621793 0.72125019 0.37196489]
image shape:(320, 320, 30) mask shape:(320, 320, 30) target_shape:[435, 435, 60]
[0 1 2 3 4]
torch.Size([1, 5, 60, 435, 435])
[0.91267955 0.89565486 0.84346557 0.9006127 ] [0.77483473 0.56661782 0.80452874 0.46962353]
image shape:(256, 256, 39) mask shape:(256, 256, 39) target_shape:[465, 465, 78]
[0 1 2 3 4]
torch.Size([1, 5, 78, 465, 465])
[0.93591267 0.90216845 0.8863868 0.93068665] [1.17235788 0.50188503 0.64777277 1.462603 ]
image shape:(320, 320, 34) mask shape:(320, 320, 34) target_shape:[435, 435, 68]
[0 1 2 3 4]
torch.Size([1, 5, 68, 435, 435])
[0.94669104 0.8950796 0.8966523 0.92914045] [0.45004656 0.43425086 0.50581982 0.34409423]
Dice per class: 0.934384 0.9030991 0.8814075 0.92374444 Overall Dice: 0.9106587767601013
ASD per class: 0.7426915185418835 0.5162332531058162 0.6698428784888762 0.6620714135699539 Overall ASD: 0.6477097659266325

W0215-git commented 1 year ago

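Thank you very much, that problem is solved. I have one more question: in your comparison, why are the ASD values of the compared methods reported to only one decimal place, while your own method's are reported to two?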

W0215-git commented 1 year ago

Other authors follow the same convention, and I don't understand why.

Kaiseem commented 1 year ago

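Hi, I'm not sure why; I only just noticed this myself. Those results were copied directly from other papers, so possibly someone earlier switched from one decimal place to two and I followed that convention.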

aeinkoupaei commented 12 months ago


I faced exactly the same problem with the ct2mr model. I would be grateful if you could tell me how you solved it.

aeinkoupaei commented 11 months ago

Hello, I hope this message finds you well. I came across the issue you faced while evaluating the ct2mr model. I've encountered the exact same problem. Here are my results for the seg_ct2mr model: Dice per class: 0.0788, 0.0374, 0.0137, 0.0127 Overall Dice: 0.0356 ASD per class: 9.4453, inf, inf, inf Overall ASD: inf I was wondering if you managed to find a solution or workaround for this problem. I'd really appreciate any guidance or feedback you can provide on how you tackled it. Thank you in advance for your time and help. @W0215-git

W0215-git commented 6 months ago


Sorry for the delayed reply. I just moved my experiment to another computer and re-ran the test; no other tricks. I hope it helps. You can download the data the author uploaded from this link: https://pan.baidu.com/s/14b-ARlbZ7hnG6aO-fil9AA?pwd=picg Extraction code: picg