Cassie07 / PathOmics

[MICCAI 2023 Oral] The official code of "Pathology-and-genomics Multimodal Transformer for Survival Outcome Prediction" (top 9%)

Stack gene data #2

Closed DaniMlk closed 8 months ago

DaniMlk commented 8 months ago

Hi, I have a problem regarding stacking the gene data. In your code, when x_omic is collected, the genes are gathered from gene_family_dict and can have different sizes, so in the end we have a list of tensors of different sizes. When we build the dataloader, how does your collate_fn manage stacking tensors of different sizes?

Cassie07 commented 8 months ago

Hello, @DaniMlk. Thanks for your interest in our study. We didn't stack the tensors when we made our dataloader; we kept the list of tensors as the input to our model. As long as the length of the gene tensor list (i.e., the number of gene groups, which is 8 in our study) is the same across patients, there is no issue with making a dataloader.

Here is an example to better understand how our genomics data is loaded.

x_path, x_omic, censorship, survival_months, label, patient_id = next(iter(dataloader))

print(len(x_omic))
> 8

print([x.shape for x in x_omic])
> [torch.Size([137]), torch.Size([623]), torch.Size([1859]), torch.Size([383]), torch.Size([447]), torch.Size([506]), torch.Size([430]), torch.Size([326])]

We use the load_genomics_z_score function to get the list of gene tensors. Then we use the CustomizedDataset class directly to construct our dataset. Throughout this process, the gene data stay in a list and are never stacked.
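To illustrate why a fixed-length list of variable-size tensors poses no problem for a batch-size-1 dataloader, here is a minimal, self-contained sketch (the dataset class and group sizes are toy stand-ins, not the repo's actual load_genomics_z_score or CustomizedDataset): PyTorch's default collate handles each list position independently, so the tensors never need to share a size across positions.

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Toy sizes for 8 gene groups (hypothetical; real sizes come from gene_family_dict)
GROUP_SIZES = [137, 623, 1859, 383, 447, 506, 430, 326]

class ToyOmicsDataset(Dataset):
    """Each item is a fixed-length list of 8 variable-size gene tensors."""
    def __init__(self, num_patients=4):
        self.num_patients = num_patients

    def __len__(self):
        return self.num_patients

    def __getitem__(self, idx):
        # One tensor per gene group; lengths differ across groups,
        # but every patient has the same 8 groups.
        return [torch.randn(n) for n in GROUP_SIZES]

loader = DataLoader(ToyOmicsDataset(), batch_size=1)
x_omic = next(iter(loader))

print(len(x_omic))      # 8 (one entry per gene group)
print(x_omic[0].shape)  # torch.Size([1, 137]) (batch dim added per group)
```

Because every patient yields a list of the same length, the default collate simply collates position by position; no custom collate_fn is needed at batch size 1.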

We stack our gene representations after the data are sent into the model. Please check the code of PathOmics_Survival_model (lines 156-167); I have also attached it below.

    if self.omic_bag == 'Attention':
        h_omic = [self.omics_attention_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(x_omic)]
    elif self.omic_bag == 'SNN_Attention':
        h_omic = []
        for idx, sig_feat in enumerate(x_omic):
            snn_feat = self.sig_networks[idx].forward(sig_feat)
            snn_feat = snn_feat.unsqueeze(0).unsqueeze(1)
            attention_feat, _ = self.mutiheadattention_networks[idx].forward(snn_feat, snn_feat, snn_feat)
            h_omic.append(attention_feat.squeeze())
    else:
        h_omic = [self.sig_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(x_omic)]  # each omic signature goes through its own FC layer

    h_omic_bag = torch.stack(h_omic)  # numGroup x 256
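The key point in the snippet above is that stacking becomes legal only after the per-group networks: each group is projected to the same embedding size, so torch.stack succeeds even though the raw inputs could not be stacked. A minimal sketch of that idea, using plain nn.Linear layers as hypothetical stand-ins for self.sig_networks (the real networks are SNN blocks with a 256-dim output):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for self.sig_networks: one FC layer per gene group,
# each mapping its group's input size to a shared 256-dim embedding.
group_sizes = [137, 623, 1859, 383, 447, 506, 430, 326]
sig_networks = nn.ModuleList([nn.Linear(n, 256) for n in group_sizes])

x_omic = [torch.randn(n) for n in group_sizes]  # one tensor per group

# Each group goes through its own network; the outputs now share a common
# size, so they can be stacked even though the inputs could not be.
h_omic = [net(feat) for net, feat in zip(sig_networks, x_omic)]
h_omic_bag = torch.stack(h_omic)

print(h_omic_bag.shape)  # torch.Size([8, 256]) -> numGroup x 256
```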

Please feel free to let me know if you have any questions. Also, please cite our paper if you find it helpful for your research.

Best, Kexin

DaniMlk commented 8 months ago

Thank you so much for your prompt reply. So, by not stacking the data, we cannot use a batch size larger than 1. Is that a valid statement?

Cassie07 commented 8 months ago

> Thank you so much for your prompt reply. So, by not stacking the data, we cannot use a batch size larger than 1. Is that a valid statement?

@DaniMlk In our study, the batch size is 1. However, this setting is not because of the varying sizes of the gene tensors. As mentioned before, even though each gene group's tensor has a different size, we aggregate them into a list with a consistent length of 8 for each patient. This ensures that the differing tensor sizes do not affect your use of the dataloader.

x_path, x_omic, censorship, survival_months, label, patient_id = next(iter(dataloader))

print(len(x_omic))
> 8

If you want to use a larger batch size, the real obstacle is the number of patches per slide, which is the kind of concern you raised: each patient has a different number of patches, so the patch tensors cannot be stacked by the dataloader. You would need to ensure the patients in each batch have the same number of patches, or modify your dataloader's collate_fn. The current gene-list setting, however, will not prevent you from increasing the batch size.
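One way to handle that modification is a custom collate_fn that stacks only the fields whose shapes agree across patients and keeps the variable-size patch bags in a plain Python list. A minimal sketch under toy assumptions (the dataset class, patch counts, and feature dimension are all hypothetical, not the repo's actual pipeline):

```python
import torch
from torch.utils.data import Dataset, DataLoader

def list_collate(batch):
    """Illustrative collate_fn: keep variable-size patch bags in a list,
    stack only the fields whose shapes agree across patients."""
    x_path = [item[0] for item in batch]               # list of (Ni, feat_dim) tensors
    labels = torch.stack([item[1] for item in batch])  # labels stack fine
    return x_path, labels

class ToySlides(Dataset):
    # Patients have different patch counts (hypothetical numbers).
    patch_counts = [100, 250, 80, 160]

    def __len__(self):
        return len(self.patch_counts)

    def __getitem__(self, idx):
        return torch.randn(self.patch_counts[idx], 1024), torch.tensor(idx % 2)

loader = DataLoader(ToySlides(), batch_size=2, collate_fn=list_collate)
x_path, labels = next(iter(loader))

print([x.shape[0] for x in x_path])  # [100, 250] -> differing patch counts, no error
print(labels.shape)                  # torch.Size([2])
```

The model would then iterate over the list of patch bags inside the batch; padding to a common patch count would be the alternative if true tensor batching is required.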