YazhouZhu19 / RPT

[MICCAI 2023] Few-Shot Medical Image Segmentation via a Region-enhanced Prototypical Transformer

BUG in class TrainDataset(Dataset) #16

Closed: LeeX-Bruce closed this issue 3 months ago

LeeX-Bruce commented 5 months ago

Hello author. I found a bug in datasets.py: its behavior does not match the setting described in the paper.

When exclude_label is passed in, the __getitem__ method of class TrainDataset should collect the indices of this patient's slices that contain any of the labels in exclude_label. The original code is:

exclude_idx = np.full(gt.shape[0], True, dtype=bool)
for i in range(len(self.exclude_label)):
    exclude_idx = exclude_idx & (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0)
    print(f'{i}_exclude_idx: {idx[exclude_idx]}')
exclude_idx = idx[exclude_idx]

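It does not do this. Because exclude_idx starts as all True and is combined with & for each label, only slices that contain every label in exclude_label are selected. With exclude_label = [1, 2, 3, 4], the exclude_idx I get for patient 38, and the sli_idx computed from it later, are shown below: [image]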
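A minimal toy reproduction may make the difference concrete. It uses a hypothetical 4-slice volume (not the patient 38 data from the screenshot) and contrasts the original AND accumulation with an OR accumulation over the same per-label slice masks.

```python
import numpy as np

# Hypothetical toy volume: 4 slices of 2x2 label maps (illustrative only).
gt = np.array([
    [[0, 0], [0, 0]],   # slice 0: background only
    [[1, 0], [0, 0]],   # slice 1: label 1 only
    [[1, 2], [3, 4]],   # slice 2: labels 1-4 all present
    [[0, 4], [0, 0]],   # slice 3: label 4 only
])
idx = np.arange(gt.shape[0])  # assumed: idx enumerates the slices
exclude_label = [1, 2, 3, 4]

# Original logic: start all True and AND the per-label masks,
# so only slices containing *every* excluded label survive.
and_mask = np.full(gt.shape[0], True, dtype=bool)
for lbl in exclude_label:
    and_mask &= np.sum(gt == lbl, axis=(1, 2)) > 0

# OR logic: start all False and OR the per-label masks,
# so any slice containing *at least one* excluded label is flagged.
or_mask = np.full(gt.shape[0], False, dtype=bool)
for lbl in exclude_label:
    or_mask |= np.sum(gt == lbl, axis=(1, 2)) > 0

print(idx[and_mask])  # [2]
print(idx[or_mask])   # [1 2 3]
```

On this toy volume the AND version flags only slice 2, so slices 1 and 3 would wrongly be treated as free of excluded labels.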

Here is my tentative fix:

exclude_idx = np.full(gt.shape[0], False, dtype=bool)
for i in range(len(self.exclude_label)):
    exclude_idx = exclude_idx | (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0)
exclude_idx = idx[exclude_idx]
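The same OR semantics can also be written without the Python loop. The sketch below uses a hypothetical helper name, slices_with_labels, and assumes idx is np.arange(gt.shape[0]), so that returning the nonzero positions of the mask matches idx[exclude_idx] in the snippet above.

```python
import numpy as np

def slices_with_labels(gt: np.ndarray, labels) -> np.ndarray:
    """Return indices of slices (axis 0) that contain at least one of `labels`.

    Hypothetical helper; intended to match the OR loop in the corrected snippet
    when idx == np.arange(gt.shape[0]).
    """
    # np.isin marks every pixel whose value is in `labels`;
    # .any over the two spatial axes reduces that to one flag per slice.
    mask = np.isin(gt, labels).any(axis=(1, 2))
    return np.flatnonzero(mask)

# Usage on a small hypothetical volume:
gt_demo = np.zeros((3, 2, 2), dtype=int)
gt_demo[1, 0, 0] = 2   # slice 1 contains label 2 (excluded)
gt_demo[2, 1, 1] = 5   # slice 2 contains only label 5 (not excluded)
print(slices_with_labels(gt_demo, [1, 2, 3, 4]))  # [1]
```

The vectorized form avoids one full pass over the volume per label, which may matter little for a handful of labels but keeps the "any of these labels" intent explicit.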

The result is as follows: [image]

YazhouZhu19 commented 3 months ago

Hi LeeX-Bruce,

Thank you for pointing out the bug in the datasets.py file. We directly reused the dataloader files from the repository https://github.com/ZJLAB-AMMI/Q-Net/tree/main/dataloaders.

We will update the dataloader code according to your comments.

LeeX-Bruce commented 3 months ago

Okay, thanks for the reply.