DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

How to add custom data? #27

Open schinto opened 3 years ago

schinto commented 3 years ago

Please add a description or a tutorial on how to load custom data.

I would like to use the clinical photosensitivity (PIH) data published by Schmidt et al., Chem. Res. Toxicol. 2019, 32, 2338−2352. The data can be downloaded as the supplementary Excel file tx9b00338_si_001.xls. Table S1 contains a column with the SMILES, the PIH value, and a Set column, which indicates the splits.

Many thanks

KiddoZhu commented 3 years ago

Hi! xls is a proprietary Excel binary format and you need to convert it to either csv or tsv format. To add custom datasets, you may follow the implementation of existing datasets, e.g. ClinTox.

To use column data in a table file for generating splits, you may follow the example of MOSES dataset.

schinto commented 3 years ago

Hi, as you proposed, I prepared the data as a CSV file:

import pandas as pd

# Excel file downloaded from the supplementary material of
# https://pubs.acs.org/doi/abs/10.1021/acs.chemrestox.9b00338
# Direct link to the Excel sheet:
# https://pubs.acs.org/doi/suppl/10.1021/acs.chemrestox.9b00338/suppl_file/tx9b00338_si_001.xls
xls_file = "./molecule-datasets/tx9b00338_si_001.xls"
# CSV output file
csv_file = "./molecule-datasets/phototox_pih.csv"

# Reading xls files requires the xlrd package
df = pd.read_excel(xls_file, sheet_name="PIH", engine="xlrd")
# Remove stray newlines inside cells
df['Canonical_Smiles'] = df['Canonical_Smiles'].str.replace("\n", "")
df['Substance'] = df['Substance'].str.replace("\n", ";")
# Map the negative and positive classes to 0/1
df['Photosensitation'] = df['Photosensitation'].replace({'no': 0, 'yes': 1})
# Remove rows with missing SMILES
df = df[df['Canonical_Smiles'].notnull()]
# Keep only the required columns
df = df[['Substance', 'Canonical_Smiles', 'Set', 'Photosensitation']]
df.to_csv(csv_file, index=False)
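
A quick sanity check on the resulting CSV (a minimal sketch; the expected counts are not given here, so inspect the output rather than asserting exact numbers):

import pandas as pd

df = pd.read_csv("./molecule-datasets/phototox_pih.csv")
# Molecules per split and class balance of the PIH label
print(df['Set'].value_counts())
print(df['Photosensitation'].value_counts())
print(len(df), "molecules in total")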

But how do I do the split?

I tried the following, but it is not clear whether the column Set, which contains the split info, is used as a target:

from collections import defaultdict

from torch.utils import data as torch_data

from torchdrug import data
from torchdrug.core import Registry as R
from torchdrug.utils import doc

@R.register("datasets.PhotoTox")
@doc.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class PhotoTox(data.MoleculeDataset):
    """
    Qualitative data of drugs approved by the FDA and those that have failed clinical
    trials for toxicity reasons.
    Statistics:
        - #Molecule: 1,417
        - #Classification task: 1
    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    csv_file = "./molecule-datasets/phototox_pih.csv"
    target_fields = ["Photosensitation", "Set"]

    def __init__(self, path, verbose=1, **kwargs):
        self.load_csv(self.csv_file, smiles_field="Canonical_Smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)

    def split(self):
        indexes = defaultdict(list)
        for i, split in enumerate(self.targets["Set"]):
            indexes[split].append(i)
        train_set = torch_data.Subset(self, indexes["Train"])
        valid_set = torch_data.Subset(self, indexes["Test"])  # no validation set in the article, only a Test and an External (Ext) set
        test_set = torch_data.Subset(self, indexes["Ext"])
        return train_set, valid_set, test_set

KiddoZhu commented 3 years ago

That's fine. Loading some fields as targets only means that the information is kept in each sample, not that it is necessarily used for training. If you set task=("Photosensitation",) in the arguments of PropertyPrediction, then the model will only be trained to predict photosensitization.

schinto commented 3 years ago

Many thanks!

In Table 4 of the article, an accuracy of 85% on the test and validation sets is reported for photosensitization (PIH) prediction by a random forest model.

In the PropertyPrediction class, only the area under the precision-recall curve (auprc) and the ROC-AUC (auroc) are available as metrics. What about other classification metrics like accuracy, sensitivity, specificity, F-score, precision, recall, error rate, and the confusion matrix?

Table S-6 in the supplementary material of the article reports, for their best DNN, a ROC-AUC of 0.810 (valid) and 0.867 (test). The GIN model in the following code resulted in a ROC-AUC of 0.681 (valid) and 0.644 (test). If the code below is correct, then this is rather low performance for the GIN model.

from collections import defaultdict

from torch.utils import data as torch_data

from torchdrug import data
from torchdrug.core import Registry as R
from torchdrug.utils import doc

@R.register("datasets.PhotoTox")
@doc.copy_args(data.MoleculeDataset.load_csv, ignore=("smiles_field", "target_fields"))
class PhotoTox(data.MoleculeDataset):
    """
    Qualitative data of drugs approved by the FDA and those that have failed clinical
    trials for toxicity reasons.
    Statistics:
        - #Molecule: 1,417
        - #Classification task: 1
    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """

    csv_file = "./molecule-datasets/phototox_pih.csv"
    target_fields = ["Photosensitation", "Set"]

    def __init__(self, path, verbose=1, **kwargs):
        self.load_csv(self.csv_file, smiles_field="Canonical_Smiles", target_fields=self.target_fields,
                      verbose=verbose, **kwargs)

    def split(self):
        indexes = defaultdict(list)
        for i, split in enumerate(self.targets["Set"]):
            indexes[split].append(i)
        train_set = torch_data.Subset(self, indexes["Train"])
        valid_set = torch_data.Subset(self, indexes["Test"])
        test_set = torch_data.Subset(self, indexes["Ext"])
        return train_set, valid_set, test_set

import torch
from torchdrug import core, models, tasks

dataset = PhotoTox("~/Projects/drugs/molecule-datasets/")
train_set, valid_set, test_set = dataset.split()

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256, 256],
                   short_cut=True, batch_norm=True, concat_hidden=True)
task = tasks.PropertyPrediction(model, task=("Photosensitation",),
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     #gpus=[0],
                     batch_size=1024)
solver.train(num_epoch=100)
solver.evaluate("valid")
solver.evaluate("test")

KiddoZhu commented 3 years ago

Currently we only have AUROC and AUPRC, as they are the most common metrics in property prediction. If you have suggestions for other commonly used metrics (e.g. used in at least 3 popular articles), we will consider adding them.
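
If you need those metrics in the meantime, one workaround (a minimal sketch, assuming PropertyPrediction.predict returns per-sample logits and PropertyPrediction.target returns the labels, as in the torchdrug version used here) is to collect predictions for a split and score them with scikit-learn:

import torch
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from torchdrug import data

def extra_metrics(task, split, threshold=0.5, batch_size=256):
    # Collect probabilities and labels over one split, then score with sklearn
    loader = data.DataLoader(split, batch_size=batch_size)
    preds, labels = [], []
    task.eval()
    with torch.no_grad():
        for batch in loader:
            preds.append(torch.sigmoid(task.predict(batch)).flatten())
            labels.append(task.target(batch).flatten())
    preds = torch.cat(preds).cpu().numpy()
    labels = torch.cat(labels).cpu().numpy().astype(int)
    hard = (preds >= threshold).astype(int)
    return {"accuracy": accuracy_score(labels, hard),
            "f1": f1_score(labels, hard),
            "confusion_matrix": confusion_matrix(labels, hard)}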

I read your code and everything looks good. However, reproducing property prediction performance involves many details, like the choice of atom and bond features, the integration of bond features in message passing, and even the number of MLP layers for the final prediction. Our boilerplate code (the one you used) is mostly tuned on MoleculeNet datasets and is not guaranteed to be optimal for other datasets.
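
For instance, the feature choice can be varied through the kwargs that your PhotoTox class forwards to load_csv, which passes them on to Molecule.from_smiles (a sketch; "default" is the standard feature set, and any other registered feature name should be checked against the torchdrug feature module):

# **kwargs flow: PhotoTox(...) -> load_csv -> Molecule.from_smiles
dataset = PhotoTox("~/Projects/drugs/molecule-datasets/",
                   node_feature="default", edge_feature="default")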

Btw, GitHub issues are a place for discussing bugs, documentation, and feature requests. They aren't a good place for asking reproduction questions about a particular paper (especially on datasets not covered by the library), since these are not general enough for the community. Thanks for your understanding.

schinto commented 3 years ago

Since the manual selection of parameters for a graph neural network (like the choice of atom and bond features, the integration of bond features in message passing, the number of MLP layers, etc.) is difficult, could you please add some automated machine learning techniques that determine the best parameters automatically, like the techniques used in AutoGL?

KiddoZhu commented 3 years ago

For hyperparameter search, you can just do a simple grid search or random search over the hyperparameter space. If you want to do very large-scale search, there are also many out-of-the-box libraries for this purpose, like Ray Tune. As these libraries are mature and should be compatible with our library, I feel there is no sufficient reason to maintain hyperparameter search code in TorchDrug, or to make users learn our logic of hyperparameter search.
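
For reference, a minimal grid search reusing the PhotoTox setup from earlier in this thread could look like the sketch below (the search space and the metric-key lookup are assumptions, not TorchDrug conventions):

import itertools

import torch
from torchdrug import core, models, tasks

best_auroc, best_config = 0.0, None
for lr, hidden in itertools.product([1e-3, 1e-4], [128, 256]):
    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[hidden] * 4,
                       short_cut=True, batch_norm=True, concat_hidden=True)
    task = tasks.PropertyPrediction(model, task=("Photosensitation",),
                                    criterion="bce", metric=("auprc", "auroc"))
    optimizer = torch.optim.Adam(task.parameters(), lr=lr)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=256)
    solver.train(num_epoch=50)
    metrics = solver.evaluate("valid")
    # Metric key names are an assumption; print the dict once to confirm
    auroc = max(float(v) for k, v in metrics.items() if "auroc" in k)
    if auroc > best_auroc:
        best_auroc, best_config = auroc, {"lr": lr, "hidden": hidden}
print(best_config, best_auroc)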