ycq091044 / ManyDG

ICLR 2023 paper - ManyDG - Dataset processing and model codes
https://openreview.net/forum?id=lcSfirnflpW

Regarding dataset not being accepted for the hospitalization #1

Closed: ramgopisetty closed this issue 1 year ago

ramgopisetty commented 1 year ago

Hi,

After reviewing model.py, I could not find any branch in the Base class that accepts the hospitalization dataset. Its __init__ only handles "seizure", "sleep", "drugrec", and "mortality":


import torch.nn as nn

# FeatureCNN_seizure, FeatureCNN_sleep, Retain, GAMENet, Transformer and L_Concat
# come from elsewhere in the repository and are not shown in this excerpt.

class Base(nn.Module):
    def __init__(self, dataset, device=None, voc_size=None, model=None, ehr_adj=None, ddi_adj=None):
        super(Base, self).__init__()
        self.dataset = dataset
        self.model = model
        if dataset == "seizure":
            self.feature_cnn = FeatureCNN_seizure()
            self.g_net = nn.Sequential(
                nn.Linear(128, 16),
                nn.ReLU(),
                nn.Linear(16, 6)
            )
        elif dataset == "sleep":
            self.feature_cnn = FeatureCNN_sleep()
            self.g_net = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 5),
            )
        elif dataset == "drugrec":
            if model == "Retain":
                self.feature_cnn = Retain(voc_size, emb_dim=64, device=device)
            elif model == "GAMENet":
                self.feature_cnn = GAMENet(voc_size, ehr_adj, ddi_adj, emb_dim=64, device=device, ddi_in_memory=True)
            self.g_net = nn.Sequential(
                nn.Linear(64, 64),
                nn.ReLU(),
                nn.Linear(64, voc_size[2]),
            )
        elif dataset == "mortality":
            # model_initialization_params() is another Base method, not shown in this excerpt.
            diagICD2idx, diagstring2idx, labname2idx, physicalexam2idx, treatment2idx, medname2idx = self.model_initialization_params()
            if model == "Transformer":
                self.feature_cnn = Transformer(diagstring2idx, diagICD2idx, physicalexam2idx, treatment2idx,
                                               medname2idx, labname2idx, emb_dim=128, device=device)
            elif model == "L_Concat":
                self.feature_cnn = L_Concat(diagstring2idx, diagICD2idx, physicalexam2idx, treatment2idx,
                                            medname2idx, labname2idx, emb_dim=128, device=device)
            self.g_net = nn.Sequential(
                nn.Linear(128, 16),
                nn.ELU(),
                nn.Linear(16, 1),
            )
        # No "readmission"/"hospitalization" branch and no else clause:
        # any other dataset name leaves feature_cnn and g_net undefined.

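To make the report concrete, here is a framework-free sketch of the same if/elif dispatch. DispatchSketch is a hypothetical stand-in, not a class from the repository: the branch names and layer sizes are copied from the excerpt above, but the heads are plain strings instead of nn.Sequential modules so the snippet runs without PyTorch. It shows that an unlisted name such as "readmission" silently yields an object with no head, so the problem would typically surface only later, as an AttributeError the first time the missing attribute is accessed.

```python
# Hypothetical stand-in for the dataset dispatch in Base.__init__.
# Heads are described as strings (sizes copied from the excerpt),
# so this runs without PyTorch.
class DispatchSketch:
    def __init__(self, dataset):
        self.dataset = dataset
        if dataset == "seizure":
            self.g_net = ("Linear(128, 16)", "ReLU", "Linear(16, 6)")
        elif dataset == "sleep":
            self.g_net = ("Linear(128, 32)", "ReLU", "Linear(32, 5)")
        elif dataset == "drugrec":
            self.g_net = ("Linear(64, 64)", "ReLU", "Linear(64, voc_size[2])")
        elif dataset == "mortality":
            self.g_net = ("Linear(128, 16)", "ELU", "Linear(16, 1)")
        # No else clause: any other name falls through without a head.


m = DispatchSketch("readmission")
print(hasattr(m, "g_net"))  # False
```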
ycq091044 commented 1 year ago

Hi, the "readmission" and "mortality" tasks share exactly the same pipeline: both are binary classification tasks that use the same set of features, and only their labels are extracted differently. You can use the "mortality" pipeline for the hospitalization/readmission task.
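A minimal sketch of what this reply implies, under loudly stated assumptions: the alias table, resolve_pipeline, the labeler functions, and the record fields ("features", "expired", "readmitted") are all hypothetical illustrations, not identifiers from the ManyDG code. The point it shows is that both tasks produce identical feature inputs and differ only in which binary label is attached, so the readmission task can be routed to the "mortality" model branch.

```python
# Hypothetical aliasing: tasks that reuse the "mortality" pipeline.
# These names are assumptions, not identifiers from the ManyDG repo.
PIPELINE_ALIASES = {"readmission": "mortality", "hospitalization": "mortality"}


def resolve_pipeline(task):
    """Return the model branch to use for a task (unchanged if not aliased)."""
    return PIPELINE_ALIASES.get(task, task)


# Hypothetical label extractors: same features per stay, different label.
def mortality_label(stay):
    return int(stay["expired"])


def readmission_label(stay):
    return int(stay["readmitted"])


LABELERS = {"mortality": mortality_label, "readmission": readmission_label}


def build_samples(stays, task):
    """Pair each stay's (shared) features with the task-specific binary label."""
    label_fn = LABELERS[task]
    return [(stay["features"], label_fn(stay)) for stay in stays]


# Toy records; field names and values are made up for illustration.
stays = [
    {"features": {"diag": ["I10"]}, "expired": False, "readmitted": True},
    {"features": {"diag": ["E11"]}, "expired": True, "readmitted": False},
]

print(resolve_pipeline("readmission"))  # mortality
print(build_samples(stays, "mortality"))
print(build_samples(stays, "readmission"))
```

In terms of the excerpt above, this would amount to constructing Base with dataset="mortality" while feeding it samples whose labels were extracted for readmission; how the repository's data-processing scripts expose that label choice is not covered in this thread.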