DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

How are missing values treated in regression model training for property prediction? #37

Open · schinto opened this issue 3 years ago

schinto commented 3 years ago

Hi, I tried to build a property prediction model on the OPV dataset (see code below). Training a GIN model on all 8 tasks fails because of missing values in the 4 subtasks whose names end in _extrapolated. However, training does not stop, even after all values have become NaN. When the 4 subtasks with missing values are excluded, training works fine.
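To illustrate why training keeps running, here is a minimal plain-PyTorch sketch. It uses a toy two-task model with a shared hidden layer standing in for the GIN encoder; nothing here is TorchDrug code. With a mean-reduced MSE, a single NaN label makes the whole batch loss NaN. Neither backward() nor the optimizer step raises an error, so the NaN gradients are written into the shared weights, and after one step every task's predictions are NaN.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy two-task regressor; the shared hidden layer plays the role of the GIN encoder.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 2),  # one output column per task
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(16, 4)
target = torch.randn(16, 2)
target[0, 1] = float("nan")  # a single missing label in the second task

# Plain MSE averages over every entry, so one NaN makes the whole loss NaN.
loss = F.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # no error is raised; the NaN gradients are simply applied

with torch.no_grad():
    pred_after = model(x)

print("loss:", loss.item())  # loss: nan
# The shared layer is now NaN, so both tasks' outputs are NaN.
print("all predictions NaN:", torch.isnan(pred_after).all().item())  # True
```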

How does TorchDrug deal with missing values in subtasks?

I'm asking because I would like to find out how robust multitask GIN models are to data sparsity. See "Effect of missing data on multitask prediction methods".

import torch
from torchdrug import core, data, datasets, tasks, models

dataset = datasets.OPV("~/molecule-datasets/")
train_set, valid_set, test_set = dataset.split()
print(f"# Train/Valid/Test: {len(train_set)}/{len(valid_set)}/{len(test_set)}")

model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[300, 300, 300, 300],
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)
subtasks = (
    "gap",
    "homo",
    "lumo",
    "spectral_overlap",
    # "homo_extrapolated", # task contains nan values
    # "lumo_extrapolated", # task contains nan values
    # "gap_extrapolated", # task contains nan values
    # "optical_lumo_extrapolated", # task contains nan values
)
task = tasks.PropertyPrediction(
    model, task=subtasks, criterion="mse", metric=("mae", "rmse"), verbose=1
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(
    task,
    train_set,
    valid_set,
    test_set,
    optimizer,
    gpus=[0],
    batch_size=256,
)
solver.train(num_epoch=3)
solver.save("opv_gin_property_prediction.pth")
solver.evaluate("valid")
solver.evaluate("test")

Thanks!

KiddoZhu commented 3 years ago

Hi! Currently TorchDrug doesn't handle missing values in the dataset. One solution is to drop the samples with missing labels by building a sample mask and applying it through torch.utils.data.Subset, though this may not be the best fit for the multi-task setting, since a sample missing a label for one task is then discarded for all tasks.
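A minimal sketch of the sample-mask approach in plain PyTorch. The TensorDataset and the labels matrix below are hypothetical stand-ins: with TorchDrug, the labels of the selected subtasks would first have to be collected from the dataset into a (num_samples, num_tasks) tensor with NaN marking missing entries, which is not shown here. The last line counts the observed labels that are thrown away together with their incomplete samples, which is the drawback for the multi-task setting mentioned above.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical stand-in for a molecule dataset: features plus a
# (num_samples, num_tasks) label matrix in which NaN marks a missing label.
features = torch.randn(6, 4)
labels = torch.tensor([
    [0.1, 1.0],
    [0.2, float("nan")],
    [0.3, 3.0],
    [float("nan"), 4.0],
    [0.5, 5.0],
    [0.6, float("nan")],
])
dataset = TensorDataset(features, labels)

# Sample mask: keep a sample only if every selected task has a label.
sample_mask = ~torch.isnan(labels).any(dim=1)
indices = sample_mask.nonzero(as_tuple=True)[0].tolist()
complete_set = Subset(dataset, indices)

print(indices)            # [0, 2, 4]
print(len(complete_set))  # 3
# Observed labels that are discarded along with their incomplete samples:
print(int((~torch.isnan(labels[~sample_mask])).sum()))  # 3
```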

In general, we may accept a feature request for handling missing values in property prediction. However, I am not sure what a robust solution for missing values in multi-task prediction looks like. If the community can agree on a solution, we can include it as part of TorchDrug.
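One commonly used alternative in multi-task learning (not a TorchDrug feature as described in this thread) is to keep every molecule and mask the loss instead, so that each sample contributes only to the tasks for which it has a label. A minimal sketch, assuming labels arrive as a (batch_size, num_tasks) tensor with NaN for missing entries; masked_mse is a hypothetical helper, and plugging it into tasks.PropertyPrediction would require a custom task, which is not shown.

```python
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: MSE averaged only over observed (non-NaN) labels.

    pred, target: (batch_size, num_tasks); NaN in `target` marks a missing label.
    """
    labeled = ~torch.isnan(target)
    # Replace missing labels *before* computing the error: multiplying a NaN
    # term by a zero mask still yields NaN values and NaN gradients.
    safe_target = torch.where(labeled, target, torch.zeros_like(target))
    squared_error = (pred - safe_target) ** 2
    # Zero out the missing entries and divide by the number of observed labels.
    num_labeled = labeled.sum().clamp(min=1)
    return (squared_error * labeled).sum() / num_labeled


# Usage: one of four labels is missing.
pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.tensor([[1.0, float("nan")], [5.0, 4.0]])
loss = masked_mse(pred, target)
loss.backward()
print(loss.item())  # about 1.3333 (squared errors 0, 4, 0 over 3 observed labels)
print(pred.grad)    # finite everywhere; zero at the missing entry
```

This weights every observed label equally across the batch; averaging per task first is a reasonable variant when some tasks are much sparser than others.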