awslabs / dgl-lifesci

Python package for graph neural networks in chemistry and biology
Apache License 2.0

How to obtain the categorical node features for GIN and GNNOGB? #220

Open Wang-Lin-boop opened 11 months ago

Wang-Lin-boop commented 11 months ago

When calling the GIN or GNNOGB modules, the forward function requires a LongTensor storing the categorical node features. This looks like it may be an attribute of DGLGraph, but I can't find a way to get it in the DGL documentation.

It would help to get a dictionary from BaseAtomFeaturizer that maps each node's categorical ID to its features, which also seems easy enough to implement manually. Does BaseAtomFeaturizer support returning such a dictionary of categorical node features?

Progressively refining this dictionary by analyzing the nodes of each batch seems like a viable option, but it doesn't seem elegant enough for me to believe it's the right approach. Can anyone tell me how you solved this problem? I'm looking forward to your reply.

Thanks!

mufeili commented 11 months ago

> When calling the GIN or GNNOGB modules, the forward function requires a LongTensor storing the categorical node features. This looks like it may be an attribute of DGLGraph, but I can't find a way to get it in the DGL documentation.

In the context of molecules, they are most likely atom types. It depends on whether your graph already has atom types extracted and stored.
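As a minimal sketch of what "atom types as categorical features" means, assuming a fixed element vocabulary chosen up front, each atom's symbol can be mapped to an integer index, giving the per-node integer IDs an embedding layer expects. The `ELEMENT_VOCAB` list and the `atom_type_ids` helper are hypothetical, not part of dgl-lifesci.

```python
# Hypothetical fixed vocabulary: the list index is the element's categorical ID.
ELEMENT_VOCAB = ["C", "N", "O", "S", "F", "Cl", "Br", "I"]
# Reserved ID for any element not in the vocabulary.
UNKNOWN_ID = len(ELEMENT_VOCAB)

ELEMENT_TO_ID = {symbol: i for i, symbol in enumerate(ELEMENT_VOCAB)}


def atom_type_ids(symbols):
    """Map a molecule's atom symbols (in atom order) to integer category IDs."""
    return [ELEMENT_TO_ID.get(s, UNKNOWN_ID) for s in symbols]


# Atoms of OCCOCC in SMILES order: O, C, C, O, C, C
print(atom_type_ids(["O", "C", "C", "O", "C", "C"]))  # [2, 0, 0, 2, 0, 0]
```

With RDKit, the symbols would come from `atom.GetSymbol()` for each atom of the molecule. Because the vocabulary is fixed before training, there is no need to grow it batch by batch; the embedding table just needs `len(ELEMENT_VOCAB) + 1` rows so the unknown ID is covered.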

> It would help to get a dictionary from BaseAtomFeaturizer that maps each node's categorical ID to its features, which also seems easy enough to implement manually. Does BaseAtomFeaturizer support returning such a dictionary of categorical node features?

Yes, it returns a dictionary of node features, as shown in the example here: https://github.com/awslabs/dgl-lifesci/blob/master/python/dgllife/utils/featurizers.py#L858

> Progressively refining this dictionary by analyzing the nodes of each batch seems like a viable option, but it doesn't seem elegant enough for me to believe it's the right approach. Can anyone tell me how you solved this problem? I'm looking forward to your reply.

It's generally recommended to first extract the node features for each molecule and then store them in the corresponding DGLGraph. When you batch multiple molecular graphs, their features will also be batched.
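To illustrate this workflow without depending on DGL, the sketch below stores per-molecule integer features in a plain dict (standing in for `g.ndata`) and batches them by concatenating in graph order, which mirrors how `dgl.batch` concatenates node features. The `Mol` class and the `batch_node_feats` function are hypothetical names for illustration only.

```python
from dataclasses import dataclass, field


@dataclass
class Mol:
    """Hypothetical stand-in for a DGLGraph: a node count plus a node feature dict."""
    num_nodes: int
    ndata: dict = field(default_factory=dict)


def batch_node_feats(graphs):
    """Concatenate each node feature across graphs, in graph order.

    Also returns each graph's node-ID offset, i.e. where its nodes
    start in the concatenated feature lists.
    """
    keys = list(graphs[0].ndata.keys())
    batched = {k: [] for k in keys}
    offsets, start = [], 0
    for g in graphs:
        offsets.append(start)
        start += g.num_nodes
        for k in keys:
            feats = g.ndata[k]
            if len(feats) != g.num_nodes:
                raise ValueError(f"feature {k!r} needs one entry per node")
            batched[k].extend(feats)
    return batched, offsets


# Features are extracted once per molecule and stored with its graph...
g1 = Mol(6, {"atom_type": [2, 0, 0, 2, 0, 0]})  # OCCOCC
g2 = Mol(3, {"atom_type": [0, 0, 2]})           # CCO
# ...and batching simply concatenates them.
feats, offsets = batch_node_feats([g1, g2])
print(feats["atom_type"])  # [2, 0, 0, 2, 0, 0, 0, 0, 2]
print(offsets)             # [0, 6]
```

With real DGL graphs, the equivalent is assigning `g.ndata['atom_type'] = torch.tensor(ids)` for each graph and then reading `dgl.batch(graphs).ndata['atom_type']`.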

Wang-Lin-boop commented 11 months ago

> In the context of molecules, they are most likely atom types. It depends on whether your graph already has atom types extracted and stored.

With my featurizer BaseAtomFeaturizer({'atom_type': ConcatFeaturizer([atom_type_one_hot, atom_hybridization_one_hot, atom_formal_charge, atom_chiral_tag_one_hot, atom_is_in_ring, atom_is_aromatic])}), the node features are long concatenated vectors, which look like this (for OCCOCC):

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
         0.]])

It looks like I can't get categorical node features directly from the existing API.
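One possible workaround, sketched here in plain Python, is to decode each concatenated row back into per-feature integer IDs. The segment widths (43 element types, 5 hybridization states, 1 formal-charge value, 4 chiral tags, 1 ring flag, 1 aromaticity flag) are an assumption based on the default allowable sets of the featurizers listed above. They add up to the 55 columns in the printed tensor, but should be checked against your dgl-lifesci version. `SEGMENTS` and `decode_row` are hypothetical names.

```python
# (name, width, is_one_hot) for each block of the concatenated feature vector.
# ASSUMPTION: widths match the default allowable sets; verify for your version.
SEGMENTS = [
    ("atom_type", 43, True),
    ("hybridization", 5, True),
    ("formal_charge", 1, False),
    ("chiral_tag", 4, True),
    ("is_in_ring", 1, False),
    ("is_aromatic", 1, False),
]


def decode_row(row):
    """Turn one concatenated feature row into {feature name: integer value}.

    One-hot blocks become the index of their 1 entry; an all-zero block
    (e.g. an element outside the allowable set) maps to `width`, used here
    as a reserved "unknown" ID. Scalar blocks are cast to int as-is.
    """
    expected = sum(width for _, width, _ in SEGMENTS)
    if len(row) != expected:
        raise ValueError(f"expected {expected} columns, got {len(row)}")
    decoded, start = {}, 0
    for name, width, is_one_hot in SEGMENTS:
        block = row[start:start + width]
        if is_one_hot:
            hot = [i for i, v in enumerate(block) if v == 1]
            decoded[name] = hot[0] if hot else width
        else:
            decoded[name] = int(block[0])
        start += width
    return decoded


# First row of the printed tensor: 1s at columns 0, 45 and 49.
row = [0.0] * 55
for col in (0, 45, 49):
    row[col] = 1.0
print(decode_row(row))
# -> atom_type 0, hybridization 2, every other feature 0
```

For a torch tensor, each row can be passed as `row.tolist()`. Note that a formal charge can be negative, so it would need shifting into a non-negative range before being used as an embedding index.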

mufeili commented 11 months ago

You can use https://github.com/awslabs/dgl-lifesci/blob/master/python/dgllife/utils/featurizers.py#L179