yuanqidu / M2Hub

MIT License

`otf_graph` always set to `True` (?) #2

Open fedeotto opened 6 months ago

fedeotto commented 6 months ago

Hi, I'm not sure if I'm doing something wrong, but I have noticed that otf_graph always gets set to True when I try to train a simple CGCNN model. I believe this is caused by data.neighbors not being created at the preprocessing stage (in AtomsToPeriodicGraphs). Does this mean that all the arguments set in the latter (like max_neigh) are systematically ignored? I'm also trying to understand the difference between the get_pbc_distances function in this repository (which does use data.neighbors, the attribute that is missing after preprocessing) and the original one in the CDVAE implementation, which uses the num_bonds attribute instead (see below; I adapted it slightly to fit the context):

```python
import torch


def get_pbc_distances_cdvae(
    pos,
    edge_index,
    cell,
    cell_offsets,
    num_atoms,  # unused in this adapted version
    num_bonds,
    return_offsets=False,
    return_distance_vec=False,
):
    # edge_index is unpacked as (j, i); distance vectors point from i to j
    j_index, i_index = edge_index

    distance_vectors = pos[j_index] - pos[i_index]

    # correct for PBC: repeat each structure's lattice once per edge
    # (num_bonds = number of edges per structure), then turn the integer
    # image offsets into Cartesian offsets
    lattice_edges = torch.repeat_interleave(cell, num_bonds, dim=0)
    offsets = torch.einsum('bi,bij->bj', cell_offsets.float(), lattice_edges)
    distance_vectors += offsets

    # compute distances
    distances = distance_vectors.norm(dim=-1)

    out = {
        "edge_index": edge_index,
        "distances": distances,
    }

    if return_distance_vec:
        out["distance_vec"] = distance_vectors

    if return_offsets:
        out["offsets"] = offsets

    return out
```

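The two PBC-correction lines in this function can be illustrated with a small NumPy analogue. This is not code from either repository: `np.repeat(..., axis=0)` stands in for `torch.repeat_interleave(..., dim=0)`, and `pbc_offsets` is a hypothetical helper name. Each structure's lattice is repeated once per edge, with num_bonds giving the edge count of each structure. Every edge can then convert its integer image shift into a Cartesian offset.

```python
import numpy as np


def pbc_offsets(cell, cell_offsets, num_bonds):
    """Per-edge Cartesian offsets, mirroring the two PBC lines above.

    cell:         (num_graphs, 3, 3) lattice matrices, rows = lattice vectors
    cell_offsets: (num_edges, 3) integer image shifts, edges grouped by graph
    num_bonds:    (num_graphs,) number of edges belonging to each graph
    """
    # Copy graph g's lattice once for each of its edges (repeat_interleave analogue).
    lattice_edges = np.repeat(cell, num_bonds, axis=0)
    # offset[e] = cell_offsets[e] @ lattice_edges[e], a sum of lattice vectors.
    return np.einsum("bi,bij->bj", cell_offsets.astype(float), lattice_edges)


# Two toy graphs: a cubic cell of side 2 with 2 edges, a cubic cell of side 3 with 1 edge.
cell = np.stack([2.0 * np.eye(3), 3.0 * np.eye(3)])
cell_offsets = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
num_bonds = np.array([2, 1])
print(pbc_offsets(cell, cell_offsets, num_bonds))
```

The per-graph edge counts must sum to the number of edges; otherwise the repeated lattices do not line up with cell_offsets.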
yuanqidu commented 6 months ago

Thanks for the questions!

otf_graph does not always need to be set to True; we have a tutorial here where we set it to False.

Judging by the code you copied from CDVAE, our get_pbc_distances function should be the same as the one in CDVAE; see here.

yuanqidu commented 6 months ago

I think I understand your concern. Here we followed the Open Catalyst Project, where for simple models the data is not computed on the fly. If you check the model files for, e.g., SchNet and CGCNN, you will see that we didn't make them check whether the data needs to be computed on the fly. Indeed, AtomsToPeriodicGraphs has default settings (radius 6, max_neigh 50).
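The sketch below gives a rough picture of what the defaults mentioned here (radius 6, max_neigh 50) mean. It builds a radius graph and keeps, for each center atom, at most `max_neighbors` of the closest atoms within the cutoff. It is a simplified, hypothetical helper (`radius_graph_max_neighbors`): brute force and without periodic images, so it does not reproduce AtomsToPeriodicGraphs. The (neighbor, center) row order of the returned `edge_index` is also an assumption.

```python
import numpy as np


def radius_graph_max_neighbors(pos, radius=6.0, max_neighbors=50):
    """Brute-force, non-periodic radius graph with a per-atom neighbor cap.

    Returns an int64 array of shape (2, num_edges) whose rows are
    (neighbor, center); each center keeps at most `max_neighbors`
    of its closest atoms within `radius`.
    """
    # Pairwise distances; exclude self-pairs by setting the diagonal to inf.
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)

    neighbors, centers = [], []
    for i in range(len(pos)):
        candidates = np.flatnonzero(dist[i] <= radius)
        # Sort candidates by distance (stable for ties) and apply the cap.
        order = np.argsort(dist[i, candidates], kind="stable")
        nearest = candidates[order][:max_neighbors]
        neighbors.extend(nearest.tolist())
        centers.extend([i] * len(nearest))
    return np.array([neighbors, centers], dtype=np.int64).reshape(2, -1)


# Three close atoms and one isolated atom on a line.
pos = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [10.0, 0, 0]])
edge_index = radius_graph_max_neighbors(pos, radius=6.0, max_neighbors=2)
print(edge_index.shape)  # (2, 6)
```

Lowering `max_neighbors` to 1 leaves 3 edges, which shows that the cap, not only the radius, determines the graph.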

fedeotto commented 6 months ago

Thanks for the answer. To summarize, what I have noticed is that data.neighbors is not created during data processing in AtomsToPeriodicGraphs. As a result, otf_graph always ends up True in the generate_graph method:

```python
def generate_graph(
    self,
    data,
    cutoff=None,
    max_neighbors=None,
    use_pbc=None,
    otf_graph=None,
):
    cutoff = cutoff or self.cutoff
    max_neighbors = max_neighbors or self.max_neighbors
    use_pbc = use_pbc or self.use_pbc
    otf_graph = otf_graph or self.otf_graph

    if not otf_graph:
        try:
            edge_index = data.edge_index

            if use_pbc:
                cell_offsets = data.cell_offsets
                neighbors = data.neighbors  # raises AttributeError if never set
                empty_image = neighbors == 0

        except AttributeError:
            logging.warning(
                "Turning otf_graph=True as required attributes not present in data object"
            )
            otf_graph = True
    # ... (rest of the method omitted)
```

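The fallback can be reproduced in isolation. The sketch below is a hypothetical reduction of the `try`/`except` above: `resolve_otf_graph` is not a function in the repository, and `SimpleNamespace` stands in for a PyG `Data` object. As soon as `use_pbc` is on and `data.neighbors` is missing, the `AttributeError` flips `otf_graph` to True.

```python
import logging
from types import SimpleNamespace


def resolve_otf_graph(data, otf_graph=False, use_pbc=True):
    """Hypothetical reduction of the fallback in generate_graph.

    Returns True when the graph must be rebuilt on the fly because a
    precomputed attribute is missing from `data`.
    """
    if not otf_graph:
        try:
            _ = data.edge_index
            if use_pbc:
                _ = data.cell_offsets
                _ = data.neighbors  # missing unless convert() stores it
        except AttributeError:
            logging.warning(
                "Turning otf_graph=True as required attributes not present in data object"
            )
            otf_graph = True
    return otf_graph


# A data object without `neighbors` falls back to on-the-fly graphs.
without_neighbors = SimpleNamespace(edge_index=[[0, 1], [1, 0]], cell_offsets=[[0, 0, 0]] * 2)
with_neighbors = SimpleNamespace(**vars(without_neighbors), neighbors=[2])
print(resolve_otf_graph(without_neighbors))  # True (and a logged warning)
print(resolve_otf_graph(with_neighbors))     # False
```

With `use_pbc=False` the `neighbors` lookup is skipped, so the precomputed graph is kept even when the attribute is absent.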
yuanqidu commented 5 months ago

Thanks for the summary! Now I understand your question. I double-checked and found that you are right: neighbors is not created in AtomsToPeriodicGraphs. Interestingly, we followed the Open Catalyst Project's implementation here, and it seems they don't create it either. Have you played with the OCP code? I think one simple solution is to add the neighbors field to the data object when it is created by AtomsToPeriodicGraphs. Have you fixed it?

fedeotto commented 5 months ago

@yuanqidu thanks for your reply. As you mentioned, I think it should be enough to add the attribute

```python
data.neighbors = torch.tensor([data.edge_index.shape[1]], dtype=torch.long)
```

when creating the PyG Data object in the convert() method of AtomsToPeriodicGraphs. This is in fact how I had seen it done in the CDVAE repository, where data.neighbors is named data.num_bonds: https://github.com/txie-93/cdvae/blob/f857f598d6f6cca5dc1ea0582d228f12dcc2c2ea/cdvae/pl_data/dataset.py#L66
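The suggested fix can be sanity-checked without PyG. Below, `convert_stub` and `collate_neighbors` are hypothetical stand-ins, using NumPy instead of torch and a dict instead of a `Data` object. Each structure stores its own edge count as a one-element array. Concatenating those arrays, which is how PyG's default collation is assumed here to treat a non-index attribute like `neighbors`, yields one count per graph in the batch.

```python
import numpy as np


def convert_stub(edge_index):
    """Hypothetical stand-in for AtomsToPeriodicGraphs.convert():
    keeps only the fields relevant to the neighbors fix."""
    edge_index = np.asarray(edge_index, dtype=np.int64)
    return {
        "edge_index": edge_index,
        # One entry per structure: its number of edges (CDVAE's num_bonds).
        "neighbors": np.array([edge_index.shape[1]], dtype=np.int64),
    }


def collate_neighbors(samples):
    # Concatenate along axis 0 with no index increment: one count per graph.
    return np.concatenate([s["neighbors"] for s in samples])


samples = [convert_stub([[0, 1, 1], [1, 0, 2]]), convert_stub([[0], [1]])]
neighbors = collate_neighbors(samples)
print(neighbors)  # [3 1]
```

Because these counts sum to the total number of edges in the batch, they can be used for the per-edge lattice expansion in the same way as num_bonds.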