Herschel555 / CAML

repository for CAML: Correlation-aware mutual learning for semi-supervised medical image segmentation

Downloading Data #3

Open RRouhi opened 11 months ago

RRouhi commented 11 months ago

Thank you for sharing your interesting work, and congratulations on your paper's acceptance at MICCAI 2023. Could you advise on how we can apply the method to our own custom dataset?

Thank you!

Herschel555 commented 11 months ago

Hi, you can modify the dataset code at https://github.com/Herschel555/CAML/blob/master/code/dataloaders/dataset.py#L89 so that it returns the transformed data in the same form (image, label) as the original dataset.

RRouhi commented 11 months ago

Thank you for your response. Do the networks require h5 files as input? My dataset consists of NIfTI images. Should I convert the NIfTI images to h5 format? Thank you in advance for your guidance.

Herschel555 commented 11 months ago

Hi, you can either convert the NIfTI images to h5 format and use the original dataset class at https://github.com/Herschel555/CAML/blob/master/code/dataloaders/dataset.py#L89, or modify the original dataset class to read and process NIfTI images directly.
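A minimal sketch of the conversion route, assuming nibabel and h5py are installed and that the original loader expects one HDF5 file per case with `image` and `label` datasets (an assumption that should be checked against dataset.py). The helper names `h5_path_for` and `nifti_pair_to_h5` and the output layout are hypothetical, and no normalization or cropping from the original preprocessing is reproduced here.

```python
import os


def h5_path_for(nifti_path, out_dir):
    """Map e.g. 'origs/case01.nii.gz' to '<out_dir>/case01.h5' (hypothetical layout)."""
    name = os.path.basename(nifti_path)
    for ext in (".nii.gz", ".nii"):
        if name.endswith(ext):
            name = name[: -len(ext)]
            break
    return os.path.join(out_dir, name + ".h5")


def nifti_pair_to_h5(image_path, mask_path, out_dir):
    """Write one image/mask pair into a single .h5 file and return its path."""
    # Imported lazily so the path helper above works without these packages.
    import h5py
    import nibabel as nib
    import numpy as np

    image = nib.load(image_path).get_fdata().astype(np.float32)
    # Labels are class indices (0, 1, 2 in this thread), so store them as integers.
    label = np.rint(nib.load(mask_path).get_fdata()).astype(np.uint8)
    if image.shape != label.shape:
        raise ValueError(f"shape mismatch: {image.shape} vs {label.shape}")

    os.makedirs(out_dir, exist_ok=True)
    out_path = h5_path_for(image_path, out_dir)
    with h5py.File(out_path, "w") as h5f:
        # The dataset names 'image' and 'label' are an assumption about the loader.
        h5f.create_dataset("image", data=image, compression="gzip")
        h5f.create_dataset("label", data=label, compression="gzip")
    return out_path
```

Calling `nifti_pair_to_h5` once per image/mask pair would produce one file per case; whether the original dataset class finds those files depends on the directory layout and list files it expects.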

RRouhi commented 11 months ago

Thank you. Could you help me troubleshoot the error described below? I have a training set of 30 NIfTI images and a test set of 10. They are stored in the folders train and test, each containing the subfolders origs and masks, which hold the original images and their corresponding labels. The labels are 1 and 2 for the left and right hippocampus and 0 for the background, so I set labelnum to 3 in the command. The lists of train and test images are generated in the code below (the dataset class, revised as you suggested). I ran the command python code/train_caml.py --labelnum 3 --gpu 0 --batch_size 4 --seed 1337 and got the error attached below. I printed idx to find the cause, but I have no idea why I get idx 39:

import os

import nibabel as nib
from torch.utils.data import Dataset


class LAHeart(Dataset):
    """LA Dataset"""

    def __init__(self, base_dir=None, split='train', num=None, transform=None, with_idx=False):
        # base_dir is ignored here; the data root is hard-coded.
        self._base_dir = 'data/LA'
        self.transform = transform
        self.sample_list = []
        self.with_idx = with_idx

        if split == 'train':
            self.image_dir = os.path.join(self._base_dir, 'my_Training Set', 'train', 'origs')
            self.mask_dir = os.path.join(self._base_dir, 'my_Training Set', 'train', 'masks')
        elif split == 'test':
            self.image_dir = os.path.join(self._base_dir, 'my_Training Set', 'test', 'origs')
            self.mask_dir = os.path.join(self._base_dir, 'my_Training Set', 'test', 'masks')

        # Collect the NIfTI filenames found in the image and mask folders.
        self.image_list = [filename for filename in os.listdir(self.image_dir) if filename.endswith('.nii.gz')]
        self.mask_list = [filename for filename in os.listdir(self.mask_dir) if filename.endswith('.nii.gz')]
        if num is not None:
            self.image_list = self.image_list[:num]
            self.mask_list = self.mask_list[:num]
        print("Total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        # Debug print: shows which index the sampler requested.
        print("idx is =", idx)
        image_name = self.image_list[idx]
        mask_name = self.mask_list[idx]

        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        image_data = nib.load(image_path).get_fdata()
        print("*****************", image_path, '\n')
        print("idx is =", idx)
        label_data = nib.load(mask_path).get_fdata()

        sample = {'image': image_data, 'label': label_data}

        if self.transform:
            sample = self.transform(sample)

        if self.with_idx:
            sample['idx'] = idx

        return sample
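A side note on the class above: `os.listdir` returns entries in arbitrary order, so two lists built independently from origs and masks are not guaranteed to line up index by index. A minimal sketch of one way to pair them, assuming each mask has exactly the same filename as its image (an assumption, since the thread does not state the naming scheme); `pair_images_and_masks` is a hypothetical helper.

```python
def pair_images_and_masks(image_names, mask_names):
    """Return sorted (image, mask) filename pairs, assuming identical names in both folders."""
    images = sorted(image_names)
    masks = set(mask_names)
    # Fail early if any image has no mask with the same filename.
    missing = [name for name in images if name not in masks]
    if missing:
        raise FileNotFoundError(f"no mask found for: {missing}")
    return [(name, name) for name in images]
```

The filename lists from `os.listdir(self.image_dir)` and `os.listdir(self.mask_dir)` could be passed through this helper before slicing with `num`, so that index `idx` always refers to a matching image and mask.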

ERROR:

(caml) **:~/miniconda3/envs/caml/CAML$ python code/train_caml.py --labelnum 3 --gpu 0 --batch_size 4 --seed 1337
Namespace(dataset_name='LA', root_path='./data/LA', exp='CAML', model='caml3d_v1', max_iteration=15000, max_samples=80, labeled_bs=2, batch_size=4, base_lr=0.01, deterministic=1, labelnum=3, seed=1337, gpu='0', lamda=0.5, consistency=1, consistency_o=0.05, consistency_rampup=40.0, temperature=0.1, memory_num=256, embedding_dim=64, num_filtered=12800)
Total 20 samples
1 itertations per epoch
0%| | 0/15001 [00:00<?, ?it/s]idx is = 2
***** data/LA/my_Training Set/train/origs/12m_146a.nii.gz

idx is = 2
idx is = 1
***** data/LA/my_Training Set/train/origs/12m_053a.nii.gz

idx is = 1
idx is = 39
0%| | 0/15001 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "//miniconda3/envs/caml/CAML/code/train_caml.py", line 143, in <module>
    for i_batch, sampled_batch in enumerate(trainloader):
  File "/i/miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "//miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "//miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
    data.reraise()
  File "//miniconda3/envs/caml/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "//miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^/miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//miniconda3/envs/caml/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/**/miniconda3/envs/caml/CAML/code/dataloaders/dataset.py", line 123, in __getitem__
    image_name = self.image_list[idx]
                 ~~~~~~~~~~~~~~~^^^^^
IndexError: list index out of range
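
One possible reading of this log, not confirmed anywhere in the thread: the Namespace shows max_samples=80 while the dataset reports only 20 samples. If the training script builds its labeled and unlabeled index lists from labelnum and max_samples (an assumption about train_caml.py that should be checked in the source), the sampler can request indices such as 39 that the dataset does not contain. Under the same assumption, labelnum would count labeled samples rather than segmentation classes. The sketch below uses a hypothetical helper, `out_of_range_indices`, to illustrate that check with the values from the log.

```python
def out_of_range_indices(dataset_len, labelnum, max_samples):
    """List sampler indices that would overflow a dataset of dataset_len items."""
    # Assumed index layout: the first `labelnum` items are labeled and the rest,
    # up to `max_samples`, are unlabeled. This mirrors a common semi-supervised
    # setup and is not taken from the CAML source.
    labeled_idxs = list(range(labelnum))
    unlabeled_idxs = list(range(labelnum, max_samples))
    return [i for i in labeled_idxs + unlabeled_idxs if i >= dataset_len]


# Values from the log above: 20 samples found, labelnum=3, max_samples=80.
bad = out_of_range_indices(dataset_len=20, labelnum=3, max_samples=80)
print(39 in bad)  # True
print(out_of_range_indices(dataset_len=20, labelnum=3, max_samples=20))  # []
```

If this turns out to be the cause, setting max_samples to the number of training images actually found would be one thing to try.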