divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Error in GraphDF's example #81

Closed: gooaah closed this issue 2 years ago

gooaah commented 2 years ago

Dear developers,

I ran into a problem during the generation step of the example in examples/ggraph/GraphDF:

Traceback (most recent call last):
  File "/cobra/u/gaoh/generate/DIG/examples/ggraph/GraphDF/run_rand_gen.py", line 41, in <module>
    mols, pure_valids = runner.run_rand_gen(conf['model'], args.model_path, args.num_mols, conf['num_min_node'], conf['num_max_node'], conf['temperature'], conf['atom_list'])
  File "/u/gaoh/conda-envs/dig/lib/python3.9/site-packages/dig/ggraph/method/GraphDF/graphdf.py", line 115, in run_rand_gen
    mol, no_resample, num_atoms = self.model.generate(atom_list=atomic_num_list, min_atoms=num_min_node, max_atoms=num_max_node, temperature=temperature)
  File "/u/gaoh/conda-envs/dig/lib/python3.9/site-packages/dig/ggraph/method/GraphDF/model/graphflow.py", line 98, in generate
    prior_node_dist = torch.distributions.OneHotCategorical(logits=self.node_base_log_probs[i]*temperature[0])
IndexError: index 20 is out of bounds for dimension 0 with size 20
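A dependency-free sketch of what the traceback suggests. It rests on two assumptions: that generate() loops over node positions up to the requested maximum, and that the first dimension of node_base_log_probs (20 here) was fixed when the model was built. The table and the generate function below are hypothetical stand-ins, not GraphDF's code; they only show why asking for index 20 fails on a 20-row table.

```python
# Hypothetical stand-in for node_base_log_probs: one row per node position,
# allocated once when the "model" is built.
MAX_SIZE = 20  # number of rows, matching "size 20" in the traceback

node_base_log_probs = [[0.0, 0.0, 0.0] for _ in range(MAX_SIZE)]


def generate(max_atoms):
    """Visit node positions 0..max_atoms-1, reading the table at each step."""
    visited = []
    for i in range(max_atoms):
        row = node_base_log_probs[i]  # raises IndexError once i == MAX_SIZE
        visited.append((i, len(row)))
    return [i for i, _ in visited]
```

Under these assumptions, generate(20) visits positions 0 to 19, while generate(21) fails at position 20, which mirrors the "index 20" in the error: the requested maximum exceeds the size the table was built with.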

The commands I ran were:

export path_to_the_model='./rand_gen_qm9/rand_gen_ckpt_10.pth'
CUDA_VISIBLE_DEVICES=0 /u/gaoh/conda-envs/dig/bin/python run_rand_gen.py --num_mols=100 --model_path=${path_to_the_model} --data=qm9

Could you please give me some suggestions?

Best, Hao

lyzustc commented 2 years ago

Thank you for pointing out this bug. We found that it was caused by an incorrectly configured parameter in examples/ggraph/GraphDF/config/rand_gen_qm9_config_dict.json, and we have updated this file.
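The thread does not say which parameter was wrong. If, as the traceback suggests, the requested node count exceeded what the model was built for, a pre-flight check like the sketch below could catch it before generation starts. conf["num_min_node"], conf["num_max_node"] and conf["model"] are the keys run_rand_gen.py reads; the model-size key name ("max_size" by default) is a hypothetical placeholder, so verify it against the actual JSON file.

```python
import json


def check_rand_gen_config(conf, model_size_key="max_size"):
    """Return a list of human-readable problems found in a rand_gen config dict.

    model_size_key is a HYPOTHETICAL name for the entry in conf["model"] that
    holds the number of node positions the model was built with; if it is
    absent, that comparison is skipped.
    """
    problems = []
    lo = conf["num_min_node"]
    hi = conf["num_max_node"]
    size = conf["model"].get(model_size_key)
    if lo > hi:
        problems.append(f"num_min_node ({lo}) > num_max_node ({hi})")
    if size is not None and hi > size:
        problems.append(
            f"num_max_node ({hi}) > model['{model_size_key}'] ({size})"
        )
    return problems


def check_config_file(path, model_size_key="max_size"):
    """Load a JSON config from disk and run check_rand_gen_config on it."""
    with open(path) as f:
        return check_rand_gen_config(json.load(f), model_size_key)
```

For example, running check_config_file("config/rand_gen_qm9_config_dict.json") from examples/ggraph/GraphDF returns an empty list when neither check finds a problem.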

gooaah commented 2 years ago

Thanks, it works!