ehoogeboom / e3_diffusion_for_molecules

MIT License

eval_sample.py looking for a model file that doesn't exist #6

Closed: Dunni3 closed this issue 1 year ago

Dunni3 commented 1 year ago

Hi, I ran into a minor issue when trying to sample molecules from the trained edm_qm9 model. In brief, I think the file outputs/edm_qm9/flow_ema.npy might be misnamed or misplaced, or left over from a time when your code used a different naming convention for output files. I'll explain in more detail below:

I ran the following command (taken from README.md) for sampling molecules:

python eval_sample.py --model_path outputs/edm_qm9 --n_samples 1

Note that I am trying to sample molecules with one of the trained models provided in this repository.

This results in the following missing file exception:

Exception has occurred: FileNotFoundError
[Errno 2] No such file or directory: 'outputs/edm_qm9/generative_model_ema.npy'
  File ".../e3_diffusion_for_molecules/eval_sample.py", line 137, in main
    flow_state_dict = torch.load(join(eval_args.model_path, fn),
  File "...e3_diffusion_for_molecules/eval_sample.py", line 164, in <module>
    main()

It seems that the arguments stored in args.pickle cause the function eval_sample.main to look for outputs/edm_qm9/generative_model_ema.npy, a file that does not exist in the repository.
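The lookup described above can be sketched as follows. This is a minimal illustration, not the repository's code: it assumes the saved args carry an `ema_decay` attribute whose positive value selects the `_ema` filename, and `checkpoint_filename` / `resolve_checkpoint` are hypothetical helpers. The second helper fails early with a message listing the `.npy` files that are actually present, which would surface a stray `flow_ema.npy` directly.

```python
import os
import tempfile
from types import SimpleNamespace


def checkpoint_filename(args) -> str:
    # Assumption: a positive `ema_decay` in the saved training args means the
    # EMA weights were saved, so the `_ema` checkpoint name is expected.
    if getattr(args, "ema_decay", 0) > 0:
        return "generative_model_ema.npy"
    return "generative_model.npy"


def resolve_checkpoint(model_path: str, args) -> str:
    """Return the checkpoint path, or fail with a message listing what exists."""
    path = os.path.join(model_path, checkpoint_filename(args))
    if not os.path.exists(path):
        # Show the .npy files that *are* present, e.g. a legacy flow_ema.npy.
        present = sorted(f for f in os.listdir(model_path) if f.endswith(".npy"))
        raise FileNotFoundError(f"{path} not found; .npy files present: {present}")
    return path


if __name__ == "__main__":
    # Reproduce the layout from this issue: only the legacy name is present.
    with tempfile.TemporaryDirectory() as d:
        open(os.path.join(d, "flow_ema.npy"), "wb").close()
        try:
            resolve_checkpoint(d, SimpleNamespace(ema_decay=0.999))
        except FileNotFoundError as e:
            print(e)
```

Checking for the file before calling `torch.load` costs nothing and turns a bare "No such file or directory" into a hint about which name is actually on disk.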

When I train a diffusion model myself on the qm9 dataset, using the command given in the README, the training code produces a file named generative_model_ema.npy, and I am able to run eval_sample.py successfully when pointing it to the args/model files for the model I trained.

The repository does contain a model file at outputs/edm_qm9/flow_ema.npy. Is this file perhaps misnamed, or is it from another experiment? I suspect it needs to be renamed.

If I'm right, this might be an important update to make.

PS: Congrats on putting out this awesome work and thank you for making it so accessible!

ehoogeboom commented 1 year ago

Hi, good catch, thank you for pointing that out. As you suspected, "flow_ema.npy" should indeed be named "generative_model_ema.npy". I updated the repository :)
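For a clone made before this fix, the same rename can be applied locally (from a shell, `mv outputs/edm_qm9/flow_ema.npy outputs/edm_qm9/generative_model_ema.npy`). A minimal Python sketch, assuming the file layout described in this thread; `rename_legacy_checkpoint` is a hypothetical helper, not part of the repository. It does nothing if the new name already exists, so running it twice is harmless.

```python
import os
import tempfile


def rename_legacy_checkpoint(model_dir: str,
                             old: str = "flow_ema.npy",
                             new: str = "generative_model_ema.npy") -> bool:
    """Rename a legacy checkpoint in place; return True if a rename happened."""
    old_path = os.path.join(model_dir, old)
    new_path = os.path.join(model_dir, new)
    if os.path.exists(new_path) or not os.path.exists(old_path):
        # Nothing to do: already renamed, or the legacy file is absent.
        return False
    os.rename(old_path, new_path)
    return True


if __name__ == "__main__":
    # Demo on a throwaway directory rather than a real model folder.
    with tempfile.TemporaryDirectory() as d:
        open(os.path.join(d, "flow_ema.npy"), "wb").close()
        print(rename_legacy_checkpoint(d))  # True: file renamed
        print(rename_legacy_checkpoint(d))  # False: already renamed
        print(sorted(os.listdir(d)))        # ['generative_model_ema.npy']
```

Usage on a real clone would be `rename_legacy_checkpoint("outputs/edm_qm9")`, after which the README sampling command should find the expected file.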