divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

'batch_data.batch' in the SphereNet model: what does it mean and how do you get it? #161

Closed flpan closed 1 year ago

flpan commented 1 year ago

In https://github.com/divelab/DIG/blob/dig-stable/dig/threedgraph/method/spherenet/spherenet.py, at line 285: def forward(self, batch_data): z, pos, batch = batch_data.z, batch_data.pos, batch_data.batch

I know that 'z' represents the atom type and 'pos' represents the 3D coordinates, but I don't know how 'batch' is obtained. Could you please help me with it? I couldn't find where it is set in the qm9/qm17 datasets. Thank you very much!

limei0307 commented 1 year ago

Hi @flpan,

We use DataLoader in PyG to obtain each batch of data. You can find the code in run.py (line 53 and line 121). We also provide an example to train the model. Please let me know if you have further questions. Thanks!

Best
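To make the meaning of 'batch' concrete: when several graphs are collated into one mini-batch, their nodes are concatenated and 'batch' stores, for every node, the index of the graph it came from. The sketch below is a plain-Python illustration of that idea, not PyG's actual implementation; the helper names build_batch_vector and sum_per_graph are hypothetical.

```python
def build_batch_vector(nodes_per_graph):
    """Return one graph index per node, in concatenation order."""
    batch = []
    for graph_idx, num_nodes in enumerate(nodes_per_graph):
        batch.extend([graph_idx] * num_nodes)
    return batch


def sum_per_graph(node_values, batch):
    """Sum node-level values into one total per graph using the batch vector."""
    num_graphs = max(batch) + 1 if batch else 0
    totals = [0.0] * num_graphs
    for value, graph_idx in zip(node_values, batch):
        totals[graph_idx] += value
    return totals


# A 3-atom molecule followed by a 5-atom molecule (sizes chosen for illustration).
batch = build_batch_vector([3, 5])
print(batch)                            # [0, 0, 0, 1, 1, 1, 1, 1]
print(sum_per_graph([1.0] * 8, batch))  # [3.0, 5.0]
```

A model that predicts one value per molecule typically needs this vector to pool per-atom outputs back into per-molecule outputs, which is what the second helper imitates.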

flpan commented 1 year ago

OK, thanks for your reply! I found a solution: first I create each sample with torch_geometric.data.Data, and then I use torch_geometric.data.DataLoader to batch them. Best!