AstraZeneca / chemicalx

A PyTorch and TorchDrug based deep learning library for drug pair scoring. (KDD 2022)
https://chemicalx.readthedocs.io
Apache License 2.0
708 stars, 87 forks

Add simple GPU support to `chemicalx.pipeline()` #86

Closed: cthoyt closed this 2 years ago

cthoyt commented 2 years ago

Closes #65.

Summary

This PR adds idiomatic GPU support to `chemicalx.pipeline()`, building on the `PackedGraph.to()` method added in #84. While I like the ability to use accelerate, trying to solve two problems at once in #76 seems to open up a big can of worms (especially without tests).
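
Moving a batch to the GPU mostly comes down to calling `.to(device)` on every tensor or packed graph it holds. The sketch below shows one plausible shape for that; it is not the code in this PR. `PairBatchSketch` and its field names are hypothetical, and the only assumption about each field is that it is either `None` (feature turned off) or exposes a `.to(device)` method, as `torch.Tensor` does and as `PackedGraph` does after #84.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Optional


@dataclass(frozen=True)
class PairBatchSketch:
    """Hypothetical stand-in for a drug pair batch (NOT chemicalx's DrugPairBatch).

    Each field is either None (feature disabled) or any object with a
    ``.to(device)`` method, e.g. a torch.Tensor or a PackedGraph.
    """

    context_features: Optional[Any] = None
    drug_features_left: Optional[Any] = None
    drug_features_right: Optional[Any] = None
    drug_molecules_left: Optional[Any] = None
    drug_molecules_right: Optional[Any] = None
    labels: Optional[Any] = None

    def to(self, device: Any) -> "PairBatchSketch":
        """Return a copy with every non-None field moved to ``device``."""
        moved = {
            f.name: getattr(self, f.name).to(device)
            for f in fields(self)
            if getattr(self, f.name) is not None
        }
        # Fields that were None are left untouched by replace().
        return replace(self, **moved)
```

Returning a new batch instead of mutating the old one keeps the calling pattern `batch = batch.to(device)` the same as for plain tensors, so a training loop can treat batches and tensors uniformly.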

Changes

Blockers

codecov-commenter commented 2 years ago

Codecov Report

Merging #86 (b08d46a) into main (040fc7b) will decrease coverage by 0.56%. The diff coverage is 83.87%.


@@            Coverage Diff             @@
##             main      #86      +/-   ##
==========================================
- Coverage   95.84%   95.28%   -0.57%     
==========================================
  Files          32       33       +1     
  Lines        1419     1462      +43     
==========================================
+ Hits         1360     1393      +33     
- Misses         59       69      +10     
Impacted Files                      Coverage Δ
chemicalx/models/ssiddi.py          98.79% <ø> (+0.06%) ↑
chemicalx/compat.py                 57.14% <57.14%> (ø)
chemicalx/data/drugpairbatch.py     96.00% <91.66%> (-4.00%) ↓
chemicalx/data/batchgenerator.py    98.30% <100.00%> (ø)
chemicalx/data/drugfeatureset.py    100.00% <100.00%> (ø)
chemicalx/models/deepdds.py         97.43% <100.00%> (ø)
chemicalx/models/deepdrug.py        96.96% <100.00%> (+0.19%) ↑
chemicalx/models/epgcnds.py         100.00% <100.00%> (ø)
chemicalx/models/gcnbmp.py          97.64% <100.00%> (+0.02%) ↑
chemicalx/models/mrgnn.py           100.00% <100.00%> (ø)
... and 10 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 040fc7b...b08d46a.
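
The headline figure follows from the counts in the coverage diff: line coverage is covered lines divided by tracked lines, and here hits plus misses equals lines on both sides, so there are no partially covered lines to account for. A quick check, with `coverage_percent` as a hypothetical helper name:

```python
def coverage_percent(hits: int, lines: int) -> float:
    """Line coverage as a percentage: covered lines / tracked lines * 100."""
    return 100.0 * hits / lines


# Counts copied from the Codecov diff (main vs. #86).
base = coverage_percent(hits=1360, lines=1419)
head = coverage_percent(hits=1393, lines=1462)
print(f"main: {base:.2f}%  #86: {head:.2f}%  change: {head - base:+.2f}")
# prints: main: 95.84%  #86: 95.28%  change: -0.56
```

The drop comes from 43 new tracked lines of which 33 are covered, which pulls the overall ratio below the previous 95.84%.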

cthoyt commented 2 years ago

Alright, after running the following in Google Colab, I'm confident that this works properly:

# Colab shell commands: install dependencies and this PR's branch
!python -m pip install torch
!python -m pip install "torch-scatter>=2.0.8"
!python -m pip install git+https://github.com/cthoyt/chemicalx.git@gpu-pipeline

import itertools

from chemicalx import pipeline
from chemicalx.data import DrugCombDB
from chemicalx.models import DeepSynergy

dataset = DrugCombDB()
model = DeepSynergy(context_channels=dataset.context_channels, drug_channels=dataset.drug_channels)
results = pipeline(
    dataset=dataset,
    model=model,
    batch_size=5120,
    epochs=100,
    context_features=True,
    drug_features=True,
    drug_molecules=False,
    metrics=[
        "roc_auc",
    ],
)
results.summarize()

# Every parameter and buffer of the trained model should sit on a single CUDA device
devices = {tensor.data.device for tensor in itertools.chain(model.parameters(), model.buffers())}
assert len(devices) == 1
assert "cuda" == next(iter(devices)).type
print(devices)

This printed:

100%|██████████| 100/100 [00:55<00:00,  1.80it/s]

Metric       Value
--------  --------
roc_auc   0.836179
{device(type='cuda', index=0)}

Training 100 epochs took 55 seconds on GPU vs. 64 seconds on CPU. That's not much of a speed-up; maybe there's a lot of overhead from transferring data to the device. DeepSynergy is a relatively simple model, though, so it might be a bad example for benchmarking. At least in principle, it works.
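
When comparing CPU and GPU timings like these, it helps to time both runs the same way. CUDA kernels launch asynchronously, so a wall-clock timer can stop before queued GPU work has finished. A minimal sketch, assuming a hypothetical helper `time_run` (not part of chemicalx): pass `sync=torch.cuda.synchronize` for GPU runs and leave it as `None` on CPU.

```python
import time
from typing import Callable, Optional


def time_run(fn: Callable[[], object], sync: Optional[Callable[[], None]] = None) -> float:
    """Hypothetical helper: wall-clock seconds for one call of ``fn``.

    On a GPU, pass ``sync=torch.cuda.synchronize`` so queued kernels are
    finished before the clock is read at each end of the measurement.
    """
    if sync is not None:
        sync()  # drain any work queued before the measurement starts
    start = time.perf_counter()
    fn()
    if sync is not None:
        sync()  # wait for everything fn launched to complete
    return time.perf_counter() - start
```

For example, `time_run(lambda: pipeline(...), sync=torch.cuda.synchronize)` on the GPU. Doing one short warm-up run first also keeps one-off CUDA initialization out of the measured time.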

benedekrozemberczki commented 2 years ago

Looks good, @cthoyt, feel free to merge!