CSCYQJ / MICCAI23-ProtoContra-SFDA

This is the official code of the MICCAI23 paper "Source-Free Domain Adaptation for Medical Image Segmentation via Prototype-Anchored Feature Alignment and Contrastive Learning".

There are some questions about training source domain model #14

Open 1187056859 opened 10 months ago

1187056859 commented 10 months ago

When training the source-domain model, I see that your config file sets input_dim=3, while the U-Net output has shape [batch_size, n_classes, H, W]. Does this mean that you feed 3 slices to the network and compute the loss against the label of only 1 slice? Also, what does your data look like: does each .npz file contain a 3-slice image and a 1-slice label? Thanks for your reply!

BarY7 commented 6 months ago

A bit late, but for anyone else struggling with the data preprocessing: there is a notebook in the commit history that gives some insight.

As for your question: the authors stack the previous, current, and next slices as the 3-channel network input and keep only the current slice's label, as shown in this snippet from the notebook:

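import os

import numpy as np
import SimpleITK as sitk

# train_img_list, img_dir, label_dir, site_dir and min_max_normalization
# are defined elsewhere in the notebook.
for img_name in train_img_list:
    # Label files share the image file name, with the first three characters replaced by 'label'.
    label_name = 'label'+img_name[3:]
    img = sitk.ReadImage(os.path.join(img_dir,img_name))
    img = sitk.GetArrayFromImage(img)  # array axes: (slices, rows, cols)
    label = sitk.ReadImage(os.path.join(label_dir,label_name))
    label = sitk.GetArrayFromImage(label)
    # Alternative normalizations left commented out in the notebook:
    # mean,std = np.mean(img),np.std(img)
    # img = z_score(img,-125,275)
    # min_v,max_v = img.min(),img.max()
    # img = (img - min_v) / (max_v-min_v)
    # img = min_max_normalization(img,-1.0582,4.3029)
    img = min_max_normalization(img,-125,275)
    # Move the slice axis last, giving (rows, cols, slices), then flip the row axis.
    img = np.flip(img.transpose((1,2,0)),axis=0)
    label = np.flip(label.transpose((1,2,0)),axis=0)
    for index in range(img.shape[2]):
        if index==0:
            # First slice: there is no previous slice, so repeat slice 0.
            img_ = img[:,:,[0,0,1]]
        elif index==img.shape[2]-1:
            # Last slice: there is no next slice, so repeat the last slice.
            img_ = img[:,:,[index-1,index,index]]
        else:
            # Interior slice: previous, current, next.
            img_ = img[:,:,index-1:index+2]
        # The label is the single center slice only.
        seg = label[:,:,index]
        assert img_.shape[2]==3
        # One .npz per slice: 'image' is (rows, cols, 3), 'label' is (rows, cols).
        np.savez(os.path.join(site_dir,'train','{}_{}.npz'.format(img_name[3:7],index)),image = img_,label = seg)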
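To check the shapes without the original data, here is a minimal sketch of the same neighbor-slice stacking on a synthetic volume. The helper name `neighbor_stack` and the random arrays are made up for illustration and are not part of the repository. Normalization and file I/O are left out. Clipping the indices gives the same result as the three branches in the snippet above.

```python
import numpy as np


def neighbor_stack(volume, index):
    """Return slices (index-1, index, index+1) of an (H, W, D) volume as (H, W, 3).

    Hypothetical helper: out-of-range neighbors are replaced by the nearest
    valid slice, which matches the first/last-slice handling in the notebook.
    """
    depth = volume.shape[2]
    idx = np.clip([index - 1, index, index + 1], 0, depth - 1)
    return volume[:, :, idx]


# Synthetic stand-ins for a CT volume and its label map, both (H, W, D).
rng = np.random.default_rng(0)
volume = rng.random((4, 5, 6)).astype(np.float32)
labels = rng.integers(0, 3, size=(4, 5, 6))

# One (image, label) pair per slice, as in the per-slice .npz files.
samples = [
    (neighbor_stack(volume, i), labels[:, :, i])
    for i in range(volume.shape[2])
]

image, seg = samples[0]
print(image.shape, seg.shape)  # 3-channel image, single-slice label

# Assumption: a channels-first network would need (3, H, W) at load time.
chw = image.transpose(2, 0, 1)
```

So each sample is a 3-channel image with a single 2D label. That is consistent with input_dim=3 and an output of shape [batch_size, n_classes, H, W], where the loss is computed against the center slice only.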