lrjconan / GRAN

Efficient Graph Generation with Graph Recurrent Attention Networks, Deep Generative Model of Graphs, Graph Neural Networks, NeurIPS 2019
MIT License

Division By Zero #7

Closed: kevtran23 closed this issue 3 years ago

kevtran23 commented 4 years ago

Hi Renjie,
I downloaded the gran_grid.pth model using the download_model.sh script and ran `!python run_exp.py -c config/gran_grid.yaml -t` in Google Colab, but hit the following error. Any suggestions as to what might be wrong? I didn't change anything in the gran_grid.yaml file.

INFO | 2020-01-30 16:04:38,856 | run_exp.py | line 26 : Writing log file to exp/GRAN/GRANMixtureBernoulli_grid_2020-Jan-30-16-04-38_1211/log_exp_1211.txt
INFO | 2020-01-30 16:04:38,857 | run_exp.py | line 27 : Exp instance id = 1211
INFO | 2020-01-30 16:04:38,857 | run_exp.py | line 28 : Exp comment = None
INFO | 2020-01-30 16:04:38,857 | run_exp.py | line 29 : Config =

{'dataset': {'data_path': 'data/', 'dev_ratio': 0.2, 'has_node_feat': False, 'is_overwrite_precompute': False, 'is_sample_subgraph': True, 'is_save_split': False, 'loader_name': 'GRANData', 'name': 'grid', 'node_order': 'DFS', 'num_fwd_pass': 1, 'num_subgraph_batch': 50, 'train_ratio': 0.8}, 'device': 'cuda:0', 'exp_dir': 'exp/GRAN', 'exp_name': 'GRANMixtureBernoulli_grid_2020-Jan-30-16-04-38_1211', 'gpus': [0], 'model': {'block_size': 1, 'dimension_reduce': True, 'edge_weight': 1.0, 'embedding_dim': 128, 'has_attention': True, 'hidden_dim': 128, 'is_sym': True, 'max_num_nodes': 361, 'name': 'GRANMixtureBernoulli', 'num_GNN_layers': 7, 'num_GNN_prop': 1, 'num_canonical_order': 1, 'num_mix_component': 20, 'sample_stride': 1}, 'run_id': '1211', 'runner': 'GranRunner', 'save_dir': 'exp/GRAN/GRANMixtureBernoulli_grid_2020-Jan-30-16-04-38_1211', 'seed': 1234, 'test': {'batch_size': 20, 'better_vis': True, 'is_single_plot': False, 'is_test_ER': False, 'is_vis': True, 'num_test_gen': 20, 'num_vis': 20, 'num_workers': 0, 'test_model_dir': 'snapshot_model', 'test_model_name': 'gran_grid.pth', 'vis_num_row': 5}, 'train': {'batch_size': 1, 'display_iter': 10, 'is_resume': False, 'lr': 0.0001, 'lr_decay': 0.3, 'lr_decay_epoch': [100000000], 'max_epoch': 3000, 'momentum': 0.9, 'num_workers': 0, 'optimizer': 'Adam', 'resume_dir': None, 'resume_epoch': 5000, 'resume_model': 'model_snapshot_0005000.pth', 'shuffle': True, 'snapshot_epoch': 100, 'valid_epoch': 50, 'wd': 0.0}, 'use_gpu': True, 'use_horovod': False}
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
max # nodes = 361 || mean # nodes = 210.25
max # edges = 684 || mean # edges = 391.5
INFO | 2020-01-30 16:04:38,984 | gran_runner.py | line 124 : Train/val/test = 80/20/20
INFO | 2020-01-30 16:04:38,988 | gran_runner.py | line 137 : No Edges vs. Edges in training set = 111.70632737276479
100% 1/1 [00:09<00:00, 9.14s/it]
INFO | 2020-01-30 16:04:51,079 | gran_runner.py | line 314 : Average test time per mini-batch = 9.139426708221436
/usr/local/lib/python3.6/dist-packages/networkx/drawing/nx_pylab.py:579: MatplotlibDeprecationWarning: The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use np.iterable instead.
  if not cb.iterable(width):
ERROR | 2020-01-30 16:05:20,040 | run_exp.py | line 42 : Traceback (most recent call last):
  File "run_exp.py", line 40, in main
    runner.test()
  File "/content/gdrive/My Drive/GRAN/runner/gran_runner.py", line 370, in test
    mmd_degree_dev, mmd_clustering_dev, mmd_4orbits_dev, mmd_spectral_dev = evaluate(self.graphs_dev, graphs_gen, degree_only=False)
  File "/content/gdrive/My Drive/GRAN/runner/gran_runner.py", line 77, in evaluate
    mmd_4orbits = orbit_stats_all(graph_gt, graph_pred)
  File "/content/gdrive/My Drive/GRAN/utils/eval_helper.py", line 396, in orbit_stats_all
    sigma=30.0)
  File "/content/gdrive/My Drive/GRAN/utils/dist_helper.py", line 157, in compute_mmd
    disc(samples2, samples2, kernel, *args, **kwargs) - \
  File "/content/gdrive/My Drive/GRAN/utils/dist_helper.py", line 139, in disc
    d /= len(samples1) * len(samples2)
ZeroDivisionError: division by zero
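For context on the final frame of the traceback: the discrepancy in `disc` is normalized by the product of the two sample-set sizes, so an empty sample list makes it a zero division. A paraphrased sketch (an assumption, simplified from what the traceback shows in `utils/dist_helper.py`, not the repo's exact code):

```python
# Paraphrased sketch of the discrepancy computation named in the traceback
# (assumption: simplified from utils/dist_helper.py, not the exact source).
def disc(samples1, samples2, kernel, *args, **kwargs):
    """Average pairwise kernel value between two sets of graph statistics."""
    d = 0.0
    for s1 in samples1:
        for s2 in samples2:
            d += kernel(s1, s2, *args, **kwargs)
    # If the orbit-count evaluation produced no samples, both lists are
    # empty and this divides 0 by 0 -> ZeroDivisionError, as in the log.
    d /= len(samples1) * len(samples2)
    return d
```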

lrjconan commented 4 years ago

Hi, I think it is complaining that the length of `samples1` passed to the `disc` function is 0. I suspect that you did not successfully run the ORCA evaluation program (located in the `utils` folder) in your Google Colab environment. Could you check the Evaluation section to make sure that ORCA compiled successfully?
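One quick way to check this in Colab is to compile the ORCA binary and smoke-test it on a toy graph before launching the test run. A minimal sketch, under the assumptions that `orca.cpp` lives in `utils/orca/`, that the g++ flags match the repo's Evaluation instructions, and that the `node 4 <in-file> std` arguments mirror how `eval_helper.py` invokes the binary via subprocess:

```python
# Hedged sketch: compile and smoke-test ORCA from Python. Paths, compiler
# flags, and CLI arguments are assumptions based on the repo layout.
import os
import subprocess as sp

ORCA_DIR = 'utils/orca'                       # assumed location of orca.cpp
ORCA_BIN = os.path.join(ORCA_DIR, 'orca')

if not os.path.isfile(ORCA_BIN):
    # Compile ORCA (flags assumed from the repo's Evaluation instructions).
    sp.check_call(['g++', '-O2', '-std=c++11', '-o', 'orca', 'orca.cpp'],
                  cwd=ORCA_DIR)

# Smoke test on a 3-node path graph: the first line of the input file is
# "num_nodes num_edges", followed by one 0-indexed edge per line.
test_file = os.path.join(ORCA_DIR, 'smoke_test.txt')
with open(test_file, 'w') as f:
    f.write('3 2\n0 1\n1 2\n')

out = sp.check_output([ORCA_BIN, 'node', '4', test_file, 'std'])
print(out.decode())  # should include the 'orbit counts:' banner
```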

KyleAMoore commented 3 years ago

I continued to receive a ZeroDivisionError even after compiling orca. The culprit turned out to be lines 255 and 283 of eval_helper.py: the call to `str.find` was failing to locate the hardcoded `COUNT_START_STR`. I suspect this is caused by the Linux/Windows line-ending disparity (I encountered the issue on Windows 10), but I cannot verify this. The fix I found was in two parts:

  1. Change line 255 from `COUNT_START_STR = 'orbit counts: \n'` to `COUNT_START_STR = 'orbit counts:'`
  2. Change line 283 from `idx = output.find(COUNT_START_STR) + len(COUNT_START_STR)` to `idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) + 2`

I suspect this change should work on any OS, and if so I recommend adding it to the existing codebase (see the sketch below).
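To see concretely why the two-part change works, here is a minimal, self-contained illustration; the sample output string with Windows-style `\r\n` endings is an assumption, with the banner text taken from the comment above:

```python
# Minimal illustration of the suspected line-ending mismatch. The sample
# ORCA output string (with Windows '\r\n' endings) is an assumption.
output = 'orbit counts: \r\n0 1 0 0\r\n2 1 0 0\r\n'

# Before the fix: the search string hardcodes a Unix newline, so find()
# returns -1 on '\r\n' output; downstream parsing then yields no samples,
# which is what ultimately triggers the ZeroDivisionError in compute_mmd.
print(output.find('orbit counts: \n'))   # -1

# After the fix: match only the banner text, then step over the two
# trailing characters (' \r' here) to reach the start of the counts.
COUNT_START_STR = 'orbit counts:'
idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) + 2
rows = [list(map(int, line.split()))
        for line in output[idx:].strip().splitlines()]
print(rows)                              # [[0, 1, 0, 0], [2, 1, 0, 0]]
```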

lrjconan commented 3 years ago

> I continued to receive a ZeroDivisionError even after compiling orca. [...] I suspect this change should work on any OS, and if so I recommend adding it to the existing codebase.

Thanks for spotting this! Do you mind submitting a pull request?