microsoft / Graphormer

Graphormer is a general-purpose deep learning backbone for molecular modeling.
MIT License
2k stars 324 forks

Graphormer module #133

Open paridhimaheshwari2708 opened 1 year ago

paridhimaheshwari2708 commented 1 year ago

How can I use the Graphormer model with a custom dataloader and training scripts (not the fairseq commands)? My data consists of DGL graphs, and my setup uses DGL GraphConv layers. I want to use Graphormer as a torch.nn.Module (like any other GNN layer) and encode DGL graphs in my setup. How can I use the Graphormer model alone and replace the DGL GraphConv layers with Graphormer layers?

mavisguan commented 1 year ago

Hi! Based on our current implementation, it might take some extra effort to satisfy your need. You can try to wrap up the code directly related to the Graphormer model (which is scattered among 4~5 files, like graphormer/models/graphormer.py, graphormer/tasks/graph_prediction.py, etc.) into a single file. For this, you can also refer to fairseq's build_model function implementation. Then you can import it as a normal Python module.
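The suggested wrapping step can be sketched with a toy stand-in for the fairseq pattern. This is only an illustration of the build_model convention mentioned above: the class name `WrappedGraphormer` and the two `args` fields are hypothetical, not Graphormer's actual constructor signature.

```python
from argparse import Namespace

class WrappedGraphormer:
    """Toy stand-in for a single-file wrapper around the Graphormer model.

    fairseq models expose a classmethod that builds the model from an args
    namespace; keeping that shape lets the wrapped module be constructed
    without fairseq's CLI machinery.
    """

    def __init__(self, num_layers, embed_dim):
        self.num_layers = num_layers
        self.embed_dim = embed_dim

    @classmethod
    def build_model(cls, args):
        # In the real code this is where graphormer/models/graphormer.py
        # would assemble encoder layers from args; here we only copy fields.
        return cls(num_layers=args.encoder_layers,
                   embed_dim=args.encoder_embed_dim)

# Once wrapped, the model can be built from a plain argparse Namespace:
args = Namespace(encoder_layers=12, encoder_embed_dim=768)
model = WrappedGraphormer.build_model(args)
```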

paridhimaheshwari2708 commented 1 year ago

@mavisguan Hi, thank you for your suggestion. Could you also provide more information about the dataloader? I have a custom dataset of DGL graphs and task-specific sampling that happens in the dataloader. How can I wrap it into the format needed by Graphormer? Specifically, what does the input to Graphormer (batched_data in the code snippet here) look like?

mavisguan commented 1 year ago

@paridhimaheshwari2708 You can use an example DGL dataset (like qm7b), put a breakpoint before the line of code you've mentioned, and see what batched_data looks like. I use this debugging toolkit to set breakpoints in big Python projects: https://github.com/volltin/vpack. I think it's a very helpful tool for digging into Graphormer's code. We're sorry that our tutorials on customizing datasets are incomplete and a bit ambiguous; we're working on updating our tutorials, please stay tuned.
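As a lightweight alternative to a debugging toolkit, a small helper dropped in before that line can summarize batched_data's structure. This is a sketch with no Graphormer dependency; the `fake_batch` keys are illustrative stand-ins, and the helper works for anything exposing `.shape` (torch/NumPy tensors), falling back to `len()` for plain sequences.

```python
def describe_batch(batched_data):
    """Map each key of a batched_data-style dict to the shape of its value.

    Tensors expose .shape; plain sequences fall back to their length.
    """
    summary = {}
    for key, value in batched_data.items():
        shape = getattr(value, "shape", None)
        summary[key] = tuple(shape) if shape is not None else len(value)
    return summary

# Plain nested lists standing in for batched tensors:
fake_batch = {"x": [[1, 2], [3, 4]], "y": [0.5]}
print(describe_batch(fake_batch))  # {'x': 2, 'y': 1}
```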

paridhimaheshwari2708 commented 1 year ago

@mavisguan Thank you, this is really helpful! I noticed here that DGL graphs are actually converted into PyG graphs. Is that right? If so, is batched_data obtained from torch_geometric.loader.DataLoader or is it a dictionary with the above keys?

mavisguan commented 1 year ago

Yes, you're right. I think batched_data is obtained from torch_geometric.loader.DataLoader, which contains batch_size PyG graphs, and it's also a dictionary.

laowu-code commented 1 year ago

> Yes, you're right. I think batched_data is obtained from torch_geometric.loader.DataLoader, which contains batch_size PyG graphs, and it's also a dictionary.

Sorry to bother you. I am wondering how I can convert my dataset (not DGL, OGB, or PyG graphs, but a plain graph adjacency matrix) into the type that fits batched_data. Thanks a lot.

mavisguan commented 1 year ago

@laowu-code You can try to convert your dataset into a PyG custom dataset, following their official documentation https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html, then add this dataset in graphormer/data/pyg_datasets/pyg_dataset_lookup_table.py.
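The first step of that conversion, turning a dense adjacency matrix into the COO edge list that PyG's Data object expects, can be sketched in pure Python. This is an assumption-light illustration: PyG stores edges as a 2 x num_edges tensor, so the two lists below would be wrapped with torch.tensor(...) when actually building the Data object.

```python
def adj_to_edge_index(adj):
    """Convert a dense adjacency matrix (list of lists) to COO edge lists.

    Returns [rows, cols], the two rows of PyG's 2 x num_edges edge_index.
    """
    rows, cols = [], []
    for i, row in enumerate(adj):
        for j, val in enumerate(row):
            if val:  # nonzero entry means an edge i -> j
                rows.append(i)
                cols.append(j)
    return [rows, cols]

# A 3-node path graph, stored symmetrically (undirected):
adj = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
edge_index = adj_to_edge_index(adj)
# edge_index == [[0, 1, 1, 2], [1, 0, 2, 1]]
```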

laowu-code commented 1 year ago

@mavisguan Thanks for your reply, I will try it. It means a lot to me.

BrandenKeck commented 3 weeks ago

> @paridhimaheshwari2708 You can use an example DGL dataset (like qm7b), put a breakpoint before the line of code you've mentioned, and see what batched_data looks like. I use this debugging toolkit to set breakpoints in big Python projects: https://github.com/volltin/vpack. I think it's a very helpful tool for digging into Graphormer's code. We're sorry that our tutorials on customizing datasets are incomplete and a bit ambiguous; we're working on updating our tutorials, please stay tuned.

Hi mavisguan,

Is it possible to construct this dictionary manually, without a dataloader, to pass to the model for testing purposes (just to ensure that the modules are set up correctly)? For example:

batched_data = {
    "idx": ...,
    "x": ...,
    "attn_bias": ...,
    "attn_edge_type": ...,
    "spatial_pos": ...,
    "degree": ...,
    "edge_input": ...,
    "y": ...,
}

If so, what format is each component? And does this need to be wrapped as a TensorDict, or is it a dictionary of tensors? I just want to load one molecule example to see if I'm doing this correctly: I modified the code to combine "in degree" and "out degree" and to remove multi-hop. However, I am becoming confused by this line:

node_feature = self.atom_encoder(x).sum(dim=-2)

I would have expected the 'x' key to be an adjacency matrix, but it appears that this information should include node types for the encoder? But, the same dictionary key is passed to multi-head attention, so I think I'm confusing myself.
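For reference, a likely reading of that line: x is not an adjacency matrix but a [num_nodes, num_features] array of integer (categorical) atom features; atom_encoder embeds each categorical feature, and .sum(dim=-2) adds the per-feature embeddings into a single vector per node. A pure-Python sketch of that embedding-and-sum, where the embedding table is made up (integer vectors for exactness; the real nn.Embedding holds learned floats):

```python
# Hypothetical embedding table: maps each integer feature value to a vector.
EMB = {
    0: [0, 0],
    1: [2, 1],
    2: [1, 3],
}

def atom_encode(x):
    """x: [num_nodes, num_features] integer features.

    Embeds each feature and sums over the feature axis, mimicking
    atom_encoder(x).sum(dim=-2): one combined embedding per node.
    """
    out = []
    for node_feats in x:
        vecs = [EMB[f] for f in node_feats]
        out.append([sum(v[d] for v in vecs) for d in range(2)])
    return out

x = [[1, 2], [0, 1]]        # two nodes, two categorical features each
node_feature = atom_encode(x)
# node_feature == [[3, 4], [2, 1]]
```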

BrandenKeck commented 2 weeks ago

For anyone who comes across this and is also struggling with the data structure: I think I understand now after reading wrapper.py. "batched_data" appears to be a Python dictionary with tensors of the following shapes: