cvignac / DiGress

code for the paper "DiGress: Discrete Denoising diffusion for graph generation"
MIT License

Evaluating Trained Models #100

Open Markus28 opened 1 month ago

Markus28 commented 1 month ago

I was wondering how the trained models are intended to be evaluated. I don't believe that the paper states how many samples were used to compute the metrics. The code appears to give some indication but the testing functionality seems broken. Assuming we train a model via:

python main.py +experiment=planar dataset=planar ++hydra.run.dir=<HEAD>/planar_debug

I would expect that we evaluate it on the test set via:

python main.py +experiment=planar dataset=planar general.test_only=<HEAD>/planar_debug/checkpoints/planar/last.ckpt

Unfortunately, this functionality is broken and gives this stack trace:

Traceback (most recent call last):
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 244, in <module>
    main()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 176, in main
    cfg, _ = get_resume(cfg, model_kwargs)
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 31, in get_resume
    model = DiscreteDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1520, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 62, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py", line 51, in _load
    return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/pickle.py", line 1717, in load_build
    setstate(state)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 881, in __setstate__
    self.process_group = _get_default_group()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 707, in _get_default_group
    raise RuntimeError(
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

I am using torch==2.0.1+cu118 and pytorch-lightning==2.0.4, as specified in the requirements.

So how are we actually supposed to evaluate the model? I think some instructions in the README would be valuable. Thanks for your help!

Markus28 commented 1 month ago

I found that it is possible to avoid the exception by commenting out the lines here:

https://github.com/cvignac/DiGress/blob/7a36a84103a6e4b732953459515a479f12e8ff3b/src/main.py#L153

However, it is still unclear to me how many samples we should use to faithfully reproduce the results from the paper. The config experiments/planar.yaml says 40, while general_default.yaml says 10k. The former would lead to large variances in the evaluation results, while the latter would take quite long to evaluate (roughly 10 hours on an H100, I believe).
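To put a rough number on the variance concern, here is a minimal sketch that compares the standard error of a proportion-style metric (for example, a fraction of valid samples) at 40 and at 10k samples. It assumes independent samples, and the rate of 0.75 is a made-up value used only for illustration; it is not a result from the paper or the repository.

```python
import math


def standard_error(p: float, n: int) -> float:
    """Standard error of a proportion estimated from n independent samples."""
    return math.sqrt(p * (1.0 - p) / n)


# Hypothetical underlying rate, chosen only for illustration.
p = 0.75

for n in (40, 10_000):
    print(f"n={n:>6}: standard error = {standard_error(p, n):.4f}")

# The error shrinks with the square root of the sample count,
# so going from 40 to 10,000 samples reduces it by sqrt(250), about 15.8x.
print(f"ratio = {standard_error(p, 40) / standard_error(p, 10_000):.1f}")
```

Under these assumptions, 40 samples leave an uncertainty of several percentage points on such a metric, while 10k samples bring it below half a point. Distribution-distance metrics behave differently, so this is only an order-of-magnitude guide.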

yryMax commented 1 month ago

I get the same problem. It seems that if you train the model in distributed mode (via DDP), you cannot load the checkpoint by calling load_from_checkpoint directly; everything must go through the Trainer. Line 155 only retrieves the configuration, so if you use exactly the same Hydra setup as when you trained the model, it should be fine.
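Since the error message itself asks for init_process_group, one possible workaround is to create a single-process default group before the checkpoint is unpickled. The sketch below is untested against DiGress; the helper name ensure_default_process_group is hypothetical, and whether the rest of the DDP unpickling then succeeds is an assumption.

```python
import os
import tempfile


def ensure_default_process_group() -> bool:
    """Create a one-process default group if none exists yet.

    Returns True if a group was created by this call, and False if
    PyTorch is missing, distributed support is unavailable, or a
    default group has already been initialized.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return False

    if not dist.is_available() or dist.is_initialized():
        return False

    # A file-based rendezvous avoids having to pick a free TCP port.
    store_path = os.path.join(tempfile.mkdtemp(), "pg_store")
    dist.init_process_group(
        backend="gloo",  # CPU backend, sufficient for a single process
        init_method=f"file://{store_path}",
        rank=0,
        world_size=1,
    )
    return True
```

The idea would be to call this once before DiscreteDenoisingDiffusion.load_from_checkpoint(...) in get_resume, so that _get_default_group() finds a group. This only addresses the specific RuntimeError in the trace above.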

Another comment on this issue: is there a stable branch without distributed training? It would be more convenient, since I want to tweak the sampling method and the evaluation metrics. Thanks in advance.