Senyh / UCMT

[IJCAI 2023] Co-training with High-Confidence Pseudo Labels for Semi-supervised Medical Image Segmentation
35 stars 2 forks source link

How to select the uncertainty map topk small batch on the 3d heart volume data? #6

Closed zhuji423 closed 11 months ago

zhuji423 commented 11 months ago

你好,非常高兴看见你们的工作被ijcai接受,我复现了你们的代码,再ISIC数据集上面可以达到你们的效果,甚至略微超过了0.1个点,证明你们的方法是可行有效的。

但我想在3d的ct数据上面实现UCMT,关键的难点在于unfold 函数pytorch官方只能实现4d的张量,对于ct数据而言,训练时的数据维度是五维的[batch,class,depth,width,height]。我无法实现umix的这一步,请问我应该使用哪一个函数来对3d的ct数据进行不确定度最高的块的选择呢。我在这一步卡了很长时间,非常感谢你们的回复。

unfolds  = torch.nn.Unfold(kernel_size=(h, w), stride=s).to(device)
folds = torch.nn.Fold(output_size=(args.image_size, args.image_size), kernel_size=(h, w), stride=s).to(device)
x11 = unfolds(uncertainty_map11)  # B x C*kernel_size[0]*kernel_size[1] x L  8 256 256 
x11 = x11.view(B, 1, h, w, -1)  # B x C x h x w x L
x11_mean = torch.mean(x11, dim=(1, 2, 3))  # B x L
_, x11_max_index = torch.sort(x11_mean, dim=1, descending=True)  # B x L B x L
# for student 2
x22 = unfolds(uncertainty_map22)  # B x C*kernel_size[0]*kernel_size[1] x L
x22 = x22.view(B, 1, h, w, -1)  # B x C x h x w x L
x22_mean = torch.mean(x22, dim=(1, 2, 3))  # B x L
_, x22_max_index = torch.sort(x22_mean, dim=1, descending=True)  # B x L B x L
img_unfold = unfolds(imageA1).view(B, C, h, w, -1)  # B x C x h x w x L
lab_unfold = unfolds(label.float()).view(B, 1, h, w, -1)  # B x C x h x w x L
for i in range(B):## 对8张图片进行操作
    img_unfold[i, :, :, :, x11_max_index[i, :topk]] = img_unfold[i, :, :, :, x22_max_index[i, -topk:]]
    img_unfold[i, :, :, :, x22_max_index[i, :topk]] = img_unfold[i, :, :, :, x11_max_index[i, -topk:]]
    lab_unfold[i, :, :, :, x11_max_index[i, :topk]] = lab_unfold[i, :, :, :, x22_max_index[i, -topk:]]
    lab_unfold[i, :, :, :, x22_max_index[i, :topk]] = lab_unfold[i, :, :, :, x11_max_index[i, -topk:]]
image2 = folds(img_unfold.view(B, C*h*w, -1))
label2 = folds(lab_unfold.view(B, 1*h*w, -1))

image

Senyh commented 11 months ago

感谢你对我们工作的关注。 你可以用unfoldNd库https://github.com/f-dangel/unfoldNd/tree/main

unfolds = unfoldNd.UnfoldNd(kernel_size=(d, h, w), stride=(d, h, w)).to(device) folds = unfoldNd.FoldNd(output_size=(args.image_size[0], args.image_size[1], args.image_size[2]), kernel_size=(d, h, w), stride=(d, h, w)).to(device)

zhuji423 commented 11 months ago

非常感谢您的回复,我去尝试一下