havakv / pycox

Survival analysis with PyTorch

Problem with the indexes and batch size when using fit_loader #127

Closed CarlosHernandezP closed 2 years ago

CarlosHernandezP commented 2 years ago

Hello, and thank you for the wonderful package you guys have created.

I am trying to use PyCox's Logistic Hazard Method with a custom dataset.

My issue is twofold and it arises when using fit_loader with a Dataset that I have created.

  1. Even though I explicitly set batch_size=1 on the DataLoader, it enters the __getitem__ method of my Dataset twice.

  2. The idx that I receive on each call to __getitem__ is not an integer (as I am used to from past PyTorch experience) but a slice of the form slice(None, 2, None) (see the small check below).
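
For reference, such a slice behaves like a whole-batch index rather than a per-sample index. A tiny self-contained check (toy data of my own, not pycox code):

import pandas as pd

idx = slice(None, 2, None)       # the index my __getitem__ receives
df = pd.DataFrame({'slide_id': ['a', 'b', 'c']})
print(df['slide_id'].iloc[idx])  # selects the first two rows, 'a' and 'b'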

Here is my Dataset for reproducibility:

import os

import torch
import torchtuples as tt
from torch.utils.data import Dataset


class Generic_MIL_Dataset_Survival(Dataset):
    def __init__(self, slide_data, data_dir, time, event, inferring=False):
        super(Generic_MIL_Dataset_Survival, self).__init__()
        self.data_dir = data_dir
        # DataFrame with our information
        self.slide_data = slide_data
        # Needed for PyCox
        self.time, self.event = tt.tuplefy(time, event).to_tensor()

        # Needed for the test set later
        self.infer = inferring

    def __len__(self):
        # Return the length of the DataFrame as the length of the dataset
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'].iloc[idx]

        if self.data_dir:
            # Load the precomputed slide features from disk
            full_path = os.path.join(self.data_dir, 'pt_files', '{}.pt'.format(slide_id))
            features = torch.load(full_path)

            if not self.infer:
                return features, (self.time[idx], self.event[idx])
            else:
                return features
        else:
            import ipdb; ipdb.set_trace()
            return slide_id
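
In case it helps to run the class above end to end, this is roughly how a toy instance of it can be built (dummy slide ids, random features and made-up durations/events; none of this is my real data):

import os
import tempfile

import numpy as np
import pandas as pd
import torch

# Two fake slides with random feature matrices saved as .pt files
data_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(data_dir, 'pt_files'))
for slide_id in ['slide_0', 'slide_1']:
    torch.save(torch.randn(10, 512),
               os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id)))

slide_data = pd.DataFrame({'slide_id': ['slide_0', 'slide_1']})
time = np.array([3, 7], dtype='int64')      # made-up discretized durations
event = np.array([1., 0.], dtype='float32') # made-up event indicators

dataset_train = Generic_MIL_Dataset_Survival(slide_data, data_dir, time, event)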

Needless to say, I'm using the proper collate_fn:

# Collate_fn 
def collate_fn_wsi(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()
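
To double-check that the collate_fn itself behaves as I expect, here is what it does on a hand-built batch of two samples shaped like the Dataset output above (again toy tensors, not real data):

import torch

batch = [
    (torch.randn(10, 512), (torch.tensor(3), torch.tensor(1.))),
    (torch.randn(10, 512), (torch.tensor(7), torch.tensor(0.))),
]
features, (durations, events) = collate_fn_wsi(batch)
print(features.shape)     # should be torch.Size([2, 10, 512])
print(durations, events)  # should be tensor([3, 7]) tensor([1., 0.])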

This is how I instantiate the DataLoader:

train_loader = DataLoader(dataset_train, batch_size=1, num_workers=0, collate_fn=collate_fn_wsi)
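
And for completeness, this is roughly how the loader is then handed to the model (simplified: the network below is just a placeholder that mean-pools the bag of instance features, and I am calling fit_dataloader as in the pycox dataloader example notebook):

import torch
import torchtuples as tt
from pycox.models import LogisticHazard

num_durations = 10  # placeholder number of discrete time intervals

class ToyMILNet(torch.nn.Module):
    """Placeholder net: mean-pool the instance features, then a linear head."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features)

    def forward(self, x):  # x: (batch, n_instances, in_features)
        return self.fc(x.mean(dim=1))

model = LogisticHazard(ToyMILNet(512, num_durations), tt.optim.Adam(0.01))
log = model.fit_dataloader(train_loader, epochs=1, verbose=True)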

If you need anything extra, please let me know.

Thank you very much for your time.