matudor closed this issue 2 years ago.
Hi @matudor,
Thank you for your interest.
(1) We should point out that EBMs are usually quite unstable to train. Did you follow our example code, including the hyperparameters, when training? EBMs are very sensitive to the Langevin dynamics settings, and generation performance can fluctuate between consecutive epochs.
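To show why Langevin settings matter so much, here is a minimal NumPy sketch of a Langevin sampler on a toy quadratic energy. It is not the DIG/GraphEBM sampler. The toy energy, the names `energy`, `grad_energy` and `langevin_sample`, and all default values are hypothetical, and step size and noise scale are decoupled into two knobs, as is common in EBM training. The update is x ← clip(x − (step_size / 2)·∇E(x) + ε), with ε ~ N(0, noise_std²).

```python
import numpy as np

def energy(x):
    # Hypothetical toy energy with its minimum at x = 1 (stand-in for a learned energy).
    return 0.5 * np.sum((x - 1.0) ** 2, axis=-1)

def grad_energy(x):
    # Analytic gradient of the toy energy above.
    return x - 1.0

def langevin_sample(x0, n_steps=100, step_size=0.1, noise_std=0.01,
                    clamp=(-2.0, 2.0), seed=0):
    """Run Langevin dynamics from x0; all defaults are illustrative, not DIG's."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.normal(0.0, noise_std, size=x.shape)
        # Move downhill in energy, then add Gaussian noise.
        x = x - 0.5 * step_size * grad_energy(x) + noise
        # Keep samples in a bounded range, a common stabilizing trick.
        x = np.clip(x, *clamp)
    return x

# A moderate step size settles near the energy minimum...
good = langevin_sample(np.zeros((5, 3)), n_steps=200, step_size=0.1)
# ...while a step size that is too large makes the chain oscillate and never settle.
bad = langevin_sample(np.zeros((5, 3)), n_steps=100, step_size=4.0)
```

With step_size=4.0 the update reduces to (x − 1) ← −(x − 1) + ε, so samples bounce between 0 and 2 instead of converging. This is a toy version of the sensitivity described above.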
(2) I can load our trained model on my end. The parameters in the energy function are shown below. It looks like you have additional .linear_node. parameters. Did you set add_self=True for the energy function? We used add_self=False for the trained model.
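One way to spot this kind of configuration mismatch is to diff the parameter names the model expects against the names stored in the checkpoint. Below is a minimal pure-Python sketch. The helpers `diff_state_dict_keys` and `group_by_module` are hypothetical. In PyTorch, the inputs would typically be `model.state_dict().keys()` and the keys of the loaded checkpoint dict. `load_state_dict(..., strict=False)` also reports `missing_keys` and `unexpected_keys` directly.

```python
from collections import defaultdict

def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), the two lists a strict load would complain about."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model expects these; checkpoint lacks them
    unexpected = sorted(checkpoint_keys - model_keys)  # checkpoint has these; model lacks them
    return missing, unexpected

def group_by_module(keys):
    """Collapse 'a.b.param' keys into {'a.b': ['param', ...]} to see which submodules differ."""
    groups = defaultdict(list)
    for key in keys:
        module, _, param = key.rpartition(".")
        groups[module].append(param)
    return dict(groups)

# Keys from the error in this thread, plus one hypothetical shared key.
suffixes = ["bias", "weight_orig", "weight", "weight_u"]
prefixes = ["graphconv1.linear_node", "graphconv.0.linear_node", "graphconv.1.linear_node"]
model_keys = [f"{p}.{s}" for p in prefixes for s in suffixes] + ["shared.weight"]
checkpoint_keys = ["shared.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
by_module = group_by_module(missing)
```

Grouping by module shows at a glance that every missing key belongs to a `linear_node` submodule. That points to a flag such as add_self, rather than a corrupted file.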
Hi @mengliu1998, thanks for your response. Regarding (2): you're right, that was the issue. I thought I was using a 'clean' version of the notebook, but I had inadvertently used one I had already modified. Changing add_self allowed me to load the model. Regarding (1): I have tried the example code both with and without modifications and haven't been able to get a good model. Can you suggest reasonable ranges of Langevin parameters to try, and do these depend on training set size, complexity, etc.? Also, is there anything other than multiple restarts with different random seeds that can help achieve convergence? Thanks!
Hi @matudor,
Have you checked the molecules generated at consecutive epochs? We observed that generation performance can fluctuate from one epoch to the next. Thanks!
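Checking consecutive epochs can be automated with a simple uniqueness score per epoch. A collapsed epoch, where every molecule is identical, scores near zero. Here is a minimal sketch. The helpers `uniqueness` and `pick_epoch`, the 0.5 threshold, and the sample SMILES are all hypothetical. The code also assumes SMILES strings were already canonicalized upstream (for example with RDKit) and that invalid molecules are represented as empty strings.

```python
def uniqueness(smiles):
    """Fraction of distinct molecules among the valid (non-empty) SMILES."""
    valid = [s for s in smiles if s]
    return len(set(valid)) / len(valid) if valid else 0.0

def pick_epoch(samples_by_epoch, min_uniqueness=0.5):
    """Score each epoch, return the best one plus any epochs that look collapsed."""
    scores = {epoch: uniqueness(s) for epoch, s in samples_by_epoch.items()}
    best = max(scores, key=scores.get)
    collapsed = [e for e, u in scores.items() if u < min_uniqueness]
    return best, scores, collapsed

# Hypothetical generations saved after each epoch.
samples = {
    1: ["CCO", "CCO", "CCO", "CCO"],        # mode collapse: one molecule repeated
    2: ["CCO", "CCN", "c1ccccc1", "CC"],    # diverse
    3: ["CCO", "CCO", "CCN", "CCN"],        # partially collapsed
}
best, scores, collapsed = pick_epoch(samples)
```

Uniqueness alone can be fooled by diverse but unrealistic molecules, so in practice it would be combined with validity and novelty checks before picking a checkpoint.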
Hi, I enjoyed the GraphEBM preprint, but I have had a couple of issues applying the code.
When I train the models from scratch, I invariably get mode collapse: all generated molecules are identical (and unrealistic, as shown below). Any suggestions for getting the model to train successfully?
When I try to use the downloaded GraphEBM_Zinc* weights from DIG_storage, I get an incompatibility with the example model structure:
RuntimeError: Error(s) in loading state_dict for EnergyFunc: Missing key(s) in state_dict: "graphconv1.linear_node.bias", "graphconv1.linear_node.weight_orig", "graphconv1.linear_node.weight", "graphconv1.linear_node.weight_u", "graphconv.0.linear_node.bias", "graphconv.0.linear_node.weight_orig", "graphconv.0.linear_node.weight", "graphconv.0.linear_node.weight_u", "graphconv.1.linear_node.bias", "graphconv.1.linear_node.weight_orig", "graphconv.1.linear_node.weight", "graphconv.1.linear_node.weight_u".
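The weight_orig and weight_u suffixes in these missing keys look like the bookkeeping of a spectral-normalization wrapper. This is inferred from the names and is not confirmed in this thread. In such a wrapper, the raw weight is stored as weight_orig, and weight_u holds the vector used by power iteration to estimate the largest singular value. The weight is then divided by that estimate. Here is a minimal NumPy sketch of that power iteration. The function name `power_iteration` and the example matrix are hypothetical.

```python
import numpy as np

def power_iteration(weight, u, n_iters=1, eps=1e-12):
    """Estimate the largest singular value of `weight`, refining the left vector `u`."""
    assert n_iters >= 1, "need at least one iteration to define v"
    for _ in range(n_iters):
        # Alternate between right (v) and left (u) singular-vector estimates.
        v = weight.T @ u
        v = v / max(np.linalg.norm(v), eps)
        u = weight @ v
        u = u / max(np.linalg.norm(u), eps)
    sigma = float(u @ weight @ v)
    return sigma, u

# Hypothetical 2x3 weight with singular values 5 and 2.
W = np.array([[5.0, 0.0, 0.0],
              [0.0, 2.0, 0.0]])
u0 = np.random.default_rng(0).normal(size=2)
sigma, u = power_iteration(W, u0, n_iters=50)
W_sn = W / sigma  # spectrally normalized weight: largest singular value becomes 1
```

Because u is carried between forward passes, wrappers of this kind typically run only one iteration per step. This is why the vector is persisted in the state_dict alongside the weights.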