Hello, and thank you for the wonderful package you guys have created.
I am trying to use pycox's LogisticHazard method with a custom dataset.
My issue is twofold, and it arises when using fit_dataloader with a Dataset that I have created.
Even though I explicitly set batch_size=1 on the DataLoader, it enters the __getitem__ method of my Dataset twice.
The idx that I receive on each call to __getitem__ is not an integer (as I am used to from my past PyTorch experience) but a slice of the form slice(None, 2, None).
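For contrast, a plain PyTorch DataLoader with batch_size=1 and the default sampler hands __getitem__ a single integer index per call, which is the behaviour I was expecting. A minimal toy sketch (unrelated to my real dataset) that demonstrates this:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Tiny dataset that just reports the index it receives."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        print(type(idx), idx)  # prints <class 'int'> 0, 1, 2, 3
        return torch.zeros(3)

for _ in DataLoader(ToyDataset(), batch_size=1):
    pass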
Here is my Dataset for reproducibility:
import os

import torch
import torchtuples as tt
from torch.utils.data import Dataset


class Generic_MIL_Dataset_Survival(Dataset):
    def __init__(self, slide_data, data_dir, time, event, inferring=False):
        super().__init__()
        self.data_dir = data_dir
        # DataFrame with our slide information
        self.slide_data = slide_data
        # Targets as tensors, needed for pycox
        self.time, self.event = tt.tuplefy(time, event).to_tensor()
        # Needed for the test set later
        self.infer = inferring

    def __len__(self):
        # The length of the DataFrame is the length of the dataset
        return len(self.slide_data)

    def __getitem__(self, idx):
        slide_id = self.slide_data['slide_id'].iloc[idx]
        if self.data_dir:
            # Load the pre-extracted feature tensor for this slide
            full_path = os.path.join(self.data_dir, 'pt_files', '{}.pt'.format(slide_id))
            features = torch.load(full_path)
            if not self.infer:
                return features, (self.time[idx], self.event[idx])
            return features
        return slide_id
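One workaround I can imagine (purely a sketch on my end, assuming the slice is meant as ordinary start:stop indexing) would be to normalize the incoming index into an explicit list of positions before doing the lookups, along the lines of this hypothetical helper:

def normalize_index(idx, length):
    """Turn an int or a slice into an explicit list of positions."""
    if isinstance(idx, slice):
        # slice(None, 2, None) on any dataset of length >= 2 becomes [0, 1]
        return list(range(*idx.indices(length)))
    return [int(idx)]

assert normalize_index(slice(None, 2, None), 10) == [0, 1]
assert normalize_index(3, 10) == [3]

But I would rather understand where the slice comes from than paper over it.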
Needless to say, I'm using the proper collate_fn:
def collate_fn_wsi(batch):
    """Stacks the entries of a nested tuple."""
    return tt.tuplefy(batch).stack()
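To sanity-check the collate function in isolation, I can feed it a fake batch shaped like my __getitem__ output; the shapes in the comments are what I would expect tt.tuplefy(batch).stack() to produce:

import torch
import torchtuples as tt

# Two fake samples shaped like my dataset's output: (features, (time, event))
batch = [
    (torch.randn(5), (torch.tensor(1.0), torch.tensor(1.0))),
    (torch.randn(5), (torch.tensor(2.0), torch.tensor(0.0))),
]
stacked = collate_fn_wsi(batch)
# Expected: a nested tuple of stacked tensors,
# i.e. (features [2, 5], (time [2], event [2]))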
And here is how I instantiate the DataLoader:
train_loader = DataLoader(dataset_train, batch_size=1, num_workers=0, collate_fn=collate_fn_wsi)
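For completeness, this is roughly how the loader is then handed to the model (net and labtrans are placeholders for my actual MIL network and fitted label transform, which I have omitted here):

from pycox.models import LogisticHazard

model = LogisticHazard(net, tt.optim.Adam(0.001), duration_index=labtrans.cuts)
log = model.fit_dataloader(train_loader, epochs=10, verbose=True)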
If you need anything extra, please let me know.
Thank you very much for your time.