mahmoodlab / SurvPath

Modeling Dense Multimodal Interactions Between Biological Pathways and Histology for Survival Prediction - CVPR 2024

Some bugs occurred during data loading. #6

Closed: HarveyYi closed this issue 6 months ago

HarveyYi commented 7 months ago

I am grateful to the authors for providing such exceptional work, which has greatly inspired me. While replicating the code, I encountered a bug during data loading in which the data from different modalities was mismatched. Specifically, the issue is as follows:

While debugging, I noticed that when multimodal data is used, the case_id and slide_id of the current case do not match the temp_index of the omics row that is loaded for it.

I found that the issue was caused by the following code in datasets/dataset_survival.py:

        elif self.modality == "coattn":
            patch_features, mask = self._load_wsi_embs_from_path(self.data_dir, slide_ids)
            omic1 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[0]].iloc[idx])
            omic2 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[1]].iloc[idx])
            omic3 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[2]].iloc[idx])
            omic4 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[3]].iloc[idx])
            omic5 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[4]].iloc[idx])
            omic6 = torch.tensor(self.omics_data_dict["rna"][self.omic_names[5]].iloc[idx])

            return (patch_features, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c, clinical_data, mask)

        elif self.modality == "survpath":
            patch_features, mask = self._load_wsi_embs_from_path(self.data_dir, slide_ids)
            omic_list = []
            for i in range(self.num_pathways):
                omic_list.append(torch.tensor(self.omics_data_dict["rna"][self.omic_names[i]].iloc[idx]))

            return (patch_features, omic_list, label, event_time, c, clinical_data, mask)
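To illustrate how a positional `.iloc[idx]` lookup can return another case's omics row, here is a minimal sketch on toy data. The case names, the `gene_1` column, and the row order are made up for illustration; the only assumption carried over from the repository code is that the omics table has a `temp_index` column holding the case ID.

```python
import pandas as pd

# Hypothetical toy data: the omics table is stored in a different order
# than the list of cases the dataset iterates over.
case_ids = ["case_A", "case_B", "case_C"]
rna = pd.DataFrame({
    "temp_index": ["case_C", "case_A", "case_B"],
    "gene_1": [3.0, 1.0, 2.0],
})

idx = 0
case_id = case_ids[idx]

# Positional lookup: returns whichever row happens to sit at position idx.
by_position = rna["gene_1"].iloc[idx]

# Key-based lookup: selects the row whose temp_index matches the case.
by_case_id = rna.loc[rna["temp_index"] == case_id, "gene_1"].values[0]

print(case_id, by_position, by_case_id)  # case_A 3.0 1.0
```

The two lookups agree only when the omics rows are stored in the same order as the cases, which is presumably why the mismatch may not appear in every setup.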

I modified the code as follows, and it is now running correctly.

        elif self.modality in ["coattn", "coattn_motcat"]:
            patch_features, mask = self._load_wsi_embs_from_path(self.data_dir, slide_ids)
            omic1 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[0]].values[0])
            omic2 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[1]].values[0])
            omic3 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[2]].values[0])
            omic4 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[3]].values[0])
            omic5 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[4]].values[0])
            omic6 = torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[5]].values[0])

            return (patch_features, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c, clinical_data, mask)

        elif self.modality == "survpath":
            patch_features, mask = self._load_wsi_embs_from_path(self.data_dir, slide_ids)
            omic_list = []
            for i in range(self.num_pathways):
                omic_list.append(torch.tensor(self.omics_data_dict["rna"][self.omics_data_dict["rna"]["temp_index"] == case_id][self.omic_names[i]].values[0]))

            return (patch_features, omic_list, label, event_time, c, clinical_data, mask)
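The modified code filters the omics table once per pathway. A possible variant, sketched below, looks up the row once and fails loudly when a case has no unique match. The helper name `pathway_values_for_case` is hypothetical and not part of the repository, and the sketch assumes each entry of `omic_names` is a list of gene column names; it returns NumPy arrays, which could then be wrapped with `torch.tensor`.

```python
import pandas as pd

def pathway_values_for_case(rna, case_id, omic_names):
    """Return one float32 array per pathway for the given case (hypothetical helper)."""
    # Select the omics row by case ID instead of by position.
    rows = rna[rna["temp_index"] == case_id]
    if len(rows) != 1:
        raise KeyError(
            f"expected exactly one omics row for {case_id!r}, found {len(rows)}"
        )
    row = rows.iloc[0]
    # Assumption: each element of omic_names is a list of column names.
    return [row[names].to_numpy(dtype="float32") for names in omic_names]

# Toy data for illustration only.
toy_rna = pd.DataFrame({
    "temp_index": ["case_C", "case_A", "case_B"],
    "g1": [3.0, 1.0, 2.0],
    "g2": [30.0, 10.0, 20.0],
    "g3": [300.0, 100.0, 200.0],
})
toy_omic_names = [["g1", "g2"], ["g3"]]

toy_out = pathway_values_for_case(toy_rna, "case_A", toy_omic_names)
print([values.tolist() for values in toy_out])  # [[1.0, 10.0], [100.0]]
```

Raising on a missing or duplicated case ID makes a silent mismatch less likely than taking `.values[0]` from whatever the filter returns.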

I'm not sure if this is an isolated case. I feel documenting the situation could help with reproducing this work, so I raised this issue.

Finally, thank you once again for your great work.