gnina / libmolgrid

Comprehensive library for fast, GPU accelerated molecular gridding for deep learning workflows
https://gnina.github.io/libmolgrid/
Apache License 2.0

Specifying grid centers for all Examples in a batch #120

jacobdurrant closed this issue 6 months ago

jacobdurrant commented 6 months ago

I'm using molgrid to load data from a types file. By default, each grid is centered on the last coordinate set loaded (the ligand, in my case), but I would like to set the center point explicitly for each example. As far as I can tell, none of the gmaker.forward overloads accepts both a molgrid.molgrid.ExampleVec and a center parameter. Here's the workaround I'm using:

import molgrid
import torch

batch_size = 16

# Let's use a custom atom typer
def typer(atom):
    """Types an atom and returns a tuple of (type vector of floats, radius)."""

    type_vec = None
    if hasattr(atom, "GetHeteroValence"):
        type_vec = [atom.GetAtomicNum(), atom.GetHeteroValence()]
    else:
        type_vec = [atom.GetAtomicNum(), atom.GetExplicitValence()]
    return (type_vec, 1.5)
t = molgrid.PythonCallbackVectorTyper(typer, 2, ["anum", "valence"])

# Create the data set
dataset = molgrid.ExampleProvider(t, default_batch_size=batch_size, cache_structs=True)
dataset.populate("./out/train.types")

# Now generate grids for an entire batch.
gmaker = molgrid.GridMaker()

labels_tensor = torch.zeros(
    (batch_size, dataset.num_labels()),
    dtype=torch.float32,
    device="cuda",
)

data_tensor = torch.zeros(
    (batch_size,) + gmaker.grid_dimensions(4),
    dtype=torch.float32,
    device="cuda",
)

dataset.reset()
batch = dataset.next_batch()
batch.extract_labels(labels_tensor)
gmaker.forward(batch, data_tensor)

# This works (data in data_tensor), but each grid is centered on the ligand
# (last coord set loaded). What if I want to center each grid on a different
# point?

# Let's make some centers for each of the grids in the batch. For simplicity,
# let's just pick the same center for each grid.
centers = [(-2.284000, -3.443000, 30.277000) for _ in range(batch_size)]

# Get the first batch
dataset.reset()
batch = dataset.next_batch()

def forward_batch_with_explicit_centers(
    centers, batch, tensor_out, random_translate=2.0, random_rotation=False
):
    """
    Forward a batch of examples, each centered on a different point.

    Args:
    - centers (list of 3-tuples): The centers of the grids.
    - batch (molgrid.ExampleVec): The batch of examples.
    - tensor_out (torch.Tensor): The output tensor.
    - random_translate (float): The maximum amount to randomly translate.
    - random_rotation (bool): Whether to randomly rotate the grid.
    """

    # The below works, but doesn't it defeat the acceleration that comes from
    # processing the entire batch at once on a GPU?

    for example_idx, example in enumerate(batch):
        t = molgrid.Transform(
            center=centers[example_idx],
            random_translate=random_translate,
            random_rotation=random_rotation,
        )

        gmaker.forward(example, t, tensor_out[example_idx])

# Get the grids (in data_tensor) for the batch, each grid centered on the
# specified point.
forward_batch_with_explicit_centers(
    centers,
    batch,
    data_tensor,
    random_translate=0,
    random_rotation=True,
)
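The centers above are identical only for simplicity. As a minimal sketch of how distinct per-example centers might be built, assume each example has a set of reference coordinates (for instance, pocket atoms) available as a list of (x, y, z) tuples. The names `centroid` and `pocket_coords` are hypothetical and not part of molgrid.

```python
def centroid(coords):
    """Return the arithmetic mean of a non-empty list of (x, y, z) tuples."""
    n = len(coords)
    if n == 0:
        raise ValueError("coords must not be empty")
    return tuple(sum(c[i] for c in coords) / n for i in range(3))


# Hypothetical reference coordinates, one list per example in the batch.
pocket_coords = [
    [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)],
    [(1.0, 1.0, 1.0), (3.0, 5.0, 7.0)],
]

# One center per example, in the same order as the examples in the batch.
example_centers = [centroid(coords) for coords in pocket_coords]
```

The resulting list of 3-tuples has the same shape as `centers` above, so it could be passed to forward_batch_with_explicit_centers in its place, provided it has one entry per example in the batch.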

My concern is that forward_batch_with_explicit_centers will run slowly, because it loops over the examples in each batch rather than processing them all at once on the GPU.

Is there a different way you would go about addressing this problem?

Thanks.

dkoes commented 6 months ago

The batch version of the function just does a for loop anyway, so this won't be any less efficient.
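To check this on a particular machine, the two approaches can be timed side by side. The following is a rough sketch rather than a rigorous benchmark. `best_time` is a hypothetical helper, and the optional `sync` argument assumes that pending GPU work should be flushed (for example with torch.cuda.synchronize) before each clock read.

```python
import time


def best_time(fn, repeats=5, sync=None):
    """Return the fastest wall-clock time (seconds) of `repeats` calls to fn.

    `sync` is an optional zero-argument callable run before each clock read,
    e.g. torch.cuda.synchronize, so pending GPU work is included in the timing.
    """
    best = float("inf")
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        best = min(best, time.perf_counter() - start)
    return best


# Assumed usage with the objects defined earlier in this thread:
#   batched = best_time(lambda: gmaker.forward(batch, data_tensor),
#                       sync=torch.cuda.synchronize)
#   looped = best_time(lambda: forward_batch_with_explicit_centers(
#                          centers, batch, data_tensor, random_translate=0),
#                      sync=torch.cuda.synchronize)
```

If the two timings come out close, the explicit loop costs little in practice.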

jacobdurrant commented 6 months ago

Hi David. Many thanks for the quick reply. If the batch version of the function uses a for loop anyway, I'll go ahead and use this version. Thanks again.