atomistic-machine-learning / schnetpack-gschnet

G-SchNet extension for SchNetPack
MIT License

Conditional G-SchNet extension for SchNetPack 2.0 - A generative neural network for 3d molecules


G-SchNet is a generative neural network that samples molecules by sequentially placing atoms in 3d space. It can be trained on data sets of 3d molecules with variable sizes and compositions. The conditional version, cG-SchNet, explicitly takes chemical and structural properties into account to allow for targeted molecule generation.

Here we provide a re-implementation of cG-SchNet using the updated SchNetPack 2.0. Compared to previous releases, SchNetPack changed from batching molecules to batching atoms, effectively removing the need for padding the neural network inputs. G-SchNet greatly benefits from this change in terms of memory requirements, which makes it possible to train models of the same expressivity on GPUs with less VRAM.

The package contains a standardized routine for the filtering of generated molecules for validity, uniqueness, and novelty. Furthermore, we altered a few implementation details to improve scalability and simplify adaptations to custom data sets. Therefore, we recommend this version for applications of G-SchNet to new data sets and further development of the method. For reproduction of the results reported in our publications, please refer to the specific repositories:


Installation

To install schnetpack-gschnet, download this repository and use pip. For example, the following commands will clone the repository into your current working directory and install this package as well as all its dependencies (e.g. SchNetPack 2.0, PyTorch, etc.):

git clone https://github.com/atomistic-machine-learning/schnetpack-gschnet.git
pip install ./schnetpack-gschnet

Command-line interface and configuration

The schnetpack-gschnet package is built on top of schnetpack version 2.0, which is a library for atomistic neural networks with flexible customization and configuration of models and experiments. It is integrated with the PyTorch Lightning learning framework, which takes care of the boilerplate code required for training and provides a standardized, modular interface for incorporating learning tasks and data sets. Moreover, SchNetPack utilizes the hierarchical configuration framework Hydra. This makes it possible to define training runs using YAML config files that can be loaded, composed, and overridden via a powerful command-line interface (CLI).

schnetpack-gschnet is designed to leverage both the PyTorch Lightning integration and the hierarchical Hydra config files. The configs directory contains the YAML files that specify different set-ups for training and generation. It exactly follows the structure of the configs directory from schnetpack. In this way, we can compose a config using files from schnetpack, e.g. for the optimizer, as well as new config files from schnetpack-gschnet, e.g. for the generative model and the data. We recommend copying the configs directory from schnetpack-gschnet to create a personal resource of config files, e.g. with:

cp -r <path/to/schnetpack-gschnet>/src/schnetpack_gschnet/configs/. <path/to/my_gschnet_configs>

You can customize the existing configs or create new ones in that directory, e.g. to set up training of cG-SchNet with custom conditions. All hyperparameters specified in the YAML files can also be set directly in the CLI when calling the training or generation script. We will explain the most important hyperparameters and config files in the following sections on training and molecule generation. For more details on the structure of the config, the CLI, and the PyTorch Lightning integration, please refer to the software paper for SchNetPack 2.0 and the documentation of the package.

Model training

If you have copied the configs directory as recommended above, the following call will start a training run in the current working directory:

gschnet_train --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9

The call to the training script requires two arguments: the directory with configs from schnetpack-gschnet and the name of the experiment config you want to run. The experiment config is the most important file for the configuration of the training run. It determines all the hyperparameters, i.e. which model, dataset, optimizer, callbacks etc. to use. We provide three exemplary experiment configs:

Experiment name Description
gschnet_qm9 Trains an unconditioned G-SchNet model on the QM9 data set. Mostly follows the experimental setup described in the sections 5, 5.1, and 5.2 of the G-SchNet publication.
gschnet_qm9_comp_relenergy Trains a cG-SchNet model that is conditioned on the atomic composition and the relative atomic energy of molecules on the QM9 data set. Mostly follows the experimental setup described in the section "Discovery of low-energy conformations" of the cG-SchNet publication. Accordingly, a filter is applied to exclude all C7O2H10 conformations from the training and validation splits.
gschnet_qm9_gap_relenergy Trains a cG-SchNet model that is conditioned on the HOMO-LUMO gap and the relative atomic energy of molecules on the QM9 data set. Mostly follows the experimental setup described in the section "Targeting multiple properties: Discovery of low-energy structures with small HOMO-LUMO gap" of the cG-SchNet publication.

Simply change the experiment name in the call to the training script to run any of the three example experiments. PyTorch Lightning will automatically use a GPU for training if it can find one, otherwise the training will run on CPU. Please note that we do not recommend training on the CPU as it will be very slow. In the following sections, we explain the most important settings and how to customize the training run, e.g. to change the target properties for conditioning of the model or to use a custom data set instead of QM9.

Hyperparameters and experiment settings

The three provided experiment configs mostly use the same hyperparameters and only differ in the properties the generative model is conditioned on. In the following table we list the most important hyperparameters and settings including their default value in the three experiments. Here, ${} invokes variable interpolation, which means that the value of another hyperparameter from the config is inserted or a special resolver is used, e.g. to get the current working directory. All these settings can easily be changed in the CLI, e.g. using trainer.accelerator=gpu as additional argument to make sure that training runs on the gpu.

Name Value Description
run.work_dir ${hydra:runtime.cwd} The root directory for running the script. The default value sets it to the current working directory.
run.data_dir ${run.work_dir}/data The directory where training data is stored (or will be downloaded to).
run.path ${run.work_dir}/models/qm9_${globals.name} The path where the directory with results of the run (e.g. the checkpoints, the used config etc.) will be stored. For our three experiments, globals.name is set to the properties the model is conditioned on, i.e. no_conditions, comp_relenergy, and gap_relenergy, respectively.
run.id ${globals.id} The name of the directory where the results of the run will be stored. The default setting of globals.id automatically generates a unique identifier or, when available, uses the job id and hostname.
trainer.accelerator auto The type of the accelerator used for training. Supports cpu, gpu, tpu, ipu, and auto. The auto option tries to select the best accelerator automatically. Please beware that training on the CPU is not recommended as it is very slow and that we did not test this package with accelerators other than gpu.
globals.lr 1e-4 The learning rate for the optimizer at the start of the training. We use the Adam optimizer and a reduce on plateau learning rate scheduler, the corresponding settings can be found in the gschnet_task config.
globals.atom_types [1, 6, 7, 8, 9] List of all the atom types contained in molecules of the training data set (expressed as nuclear charges, i.e. here we have H, C, N, O, and F).
globals.origin_type 121 Type used for the auxiliary origin token (cannot be contained in globals.atom_types). The origin token marks the center of mass of training structures and is the starting point for the trajectories of atom placements, i.e. molecules grow around the origin token.
globals.focus_type 122 Type used for the auxiliary focus token (cannot be contained in globals.atom_types). At each step, the focus is aligned with a randomly selected atom and the atom placed next needs to be a neighbor of that focused atom.
globals.stop_type 123 Type used for the stop marker that G-SchNet predicts to mark that it cannot place more atoms in the neighborhood of the current focus (cannot be contained in globals.atom_types).
globals.model_cutoff 10. The cutoff used in the interaction blocks of the SchNet model which extracts features from the intermediate molecular structures.
globals.prediction_cutoff 5. The cutoff used to determine for which atoms around the focus the distance to the new atom is predicted.
globals.placement_cutoff 1.7 The cutoff used to determine which atoms can be placed (i.e. which are neighbors of the focus) when building a trajectory of atom placements for training.
globals.use_covalent_radii True If True, the covalent radii of atom types are additionally used to check whether atoms that are inside the globals.placement_cutoff are neighbors of the focus. We use the covalent radii provided in the ase package and check whether the distance between the focus and another atom is smaller than the sum of the covalent radii for the types of the two atoms scaled by globals.covalent_radius_factor.
globals.covalent_radius_factor 1.1 Scales the sum of the two covalent radii to relax the neighborhood criterion when globals.use_covalent_radii is True.
globals.draw_random_samples 0 The number of atom placements that are randomly drawn per molecule in the batch in each epoch (without replacement, i.e. each step can occur at most once). If 0, all atom placements are used, which means that the number of atoms in a batch scales quadratically with the number of atoms in the training molecules. Therefore, we recommend selecting a value larger than 0 to obtain linear scaling when using data sets with molecules larger than those in QM9.
globals.data_workdir null Path to a directory where the data is copied to for fast access (e.g. local storage of a node when working on a cluster). If null, the data is loaded from its original destination at data.datapath each epoch.
globals.cache_workdir null Path to a directory where the data cache is copied to for fast access (e.g. local storage of a node when working on a cluster). Only used if the results of the neighborhood list are cached. If null, the cached data is loaded from its original destination each epoch.
data.batch_size 5 The number of molecules in training and validation batches. Note that each molecule occurs multiple times in the batch as they are reconstructed in a trajectory of atom placements. The number of times they occur can be limited with globals.draw_random_samples.
data.num_train 50000 The number of molecules for the training split of the data.
data.num_val 5000 The number of molecules for the validation split of the data.
data.datapath ${run.data_dir}/qm9.db The path to the training data base file.
data.remove_uncharacterized True Whether the molecules marked as uncharacterized in QM9 are removed from the training data base. Note that the data base has to be re-downloaded and built if this setting is changed.
data.num_workers 6 The number of CPU workers spawned to load training data batches.
data.num_val_workers 4 The number of CPU workers spawned to load validation data batches.
data.num_test_workers 4 The number of CPU workers spawned to load test data batches.
data.distance_unit Ang The desired distance unit used for the coordinates of atoms. The conversion is automatically done by SchNetPack if the distance unit in the data base is different.
data.property_units.energy_U0 eV The desired unit of the property energy_U0. The conversion is automatically done by SchNetPack if the unit in the data base is different.
data.property_units.gap eV The desired unit of the property gap. The conversion is automatically done by SchNetPack if the unit in the data base is different.
callbacks.early_stopping.patience 25 The number of epochs after which the training is stopped if the validation loss did not improve. On QM9, cG-SchNet typically trains for 150-250 epochs with these settings.
callbacks.progress_bar.refresh_rate 100 The number of batches processed before the progress bar is refreshed.
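
The neighborhood criterion controlled by globals.placement_cutoff, globals.use_covalent_radii, and globals.covalent_radius_factor can be illustrated with a small stand-alone sketch. It is not the code used in the package: the function name is_neighbor is hypothetical, the radii are copied by hand for the five QM9 elements (the package reads them from ase), and treating the cutoff as inclusive is an assumption.

```python
# Covalent radii in Angstrom for H, C, N, O, F, keyed by nuclear charge.
# Assumption: values copied by hand for illustration; the package takes them from ase.
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66, 9: 0.57}


def is_neighbor(distance, type_focus, type_other,
                placement_cutoff=1.7, use_covalent_radii=True,
                covalent_radius_factor=1.1):
    """Return True if an atom at `distance` counts as a neighbor of the focus."""
    # first criterion: the atom has to be inside the placement cutoff
    if distance > placement_cutoff:
        return False
    if not use_covalent_radii:
        return True
    # second criterion: distance below the scaled sum of the covalent radii
    max_bond = covalent_radius_factor * (
        COVALENT_RADII[type_focus] + COVALENT_RADII[type_other]
    )
    return distance < max_bond


print(is_neighbor(1.09, 6, 1))  # C-H at a typical bond length
print(is_neighbor(1.60, 1, 1))  # two H atoms inside the cutoff but too far apart
print(is_neighbor(1.60, 1, 1, use_covalent_radii=False))
```

The second call shows why the additional check is useful: two hydrogen atoms at 1.6 Angstrom are inside the placement cutoff of 1.7 Angstrom, but they are rejected once the covalent radii are taken into account.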

Specifying target properties

In cG-SchNet, the target properties for the conditional distribution are embedded with a neural network block. This is implemented in the class ConditioningModule. First, each property is embedded with an individual network and then all embeddings are concatenated and processed through another block of fully connected layers. In this way, we can use any combination of target properties as conditions as long as we have embedding networks for each individual property. The base class for the embedding networks is ConditionEmbedding and the package contains three subclasses: ScalarConditionEmbedding for scalar-valued properties (e.g. energy, HOMO-LUMO gap, etc.), VectorialConditionEmbedding for vector-valued properties (e.g. fingerprints), and CompositionEmbedding for the composition of molecules.

To specify a set of target properties for an experiment, we set up the corresponding ConditioningModule in a config file. For example, the experiment gschnet_qm9_gap_relenergy, which targets HOMO-LUMO gap and relative atomic energy as conditions, uses the following config:

https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/model/conditioning/gap_relenergy.yaml#L1-L22

Both the energy and the gap are embedded using a ScalarConditionEmbedding. It projects a scalar value into vector space using a Gaussian expansion with centers between condition_min and condition_max and a spacing of grid_spacing, i.e. 5 centers in the examples above. Then, it applies a network consisting of three fully connected layers with 64 neurons to extract a vector with 64 features for the corresponding property. The individual vectors of both properties are concatenated and then processed by five fully connected layers with 128 neurons to obtain a final vector with 128 features that jointly represents energy and gap. Please note that the values for condition_min, condition_max, and grid_spacing of the Gaussian expansion need to be expressed in the units used for the properties. For example, we use eV in this case for both gap and energy as specified in the default data config for the QM9 data set.
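
To make the Gaussian expansion more concrete, here is a minimal NumPy sketch. It only mirrors the description above and is not the ScalarConditionEmbedding implementation: the function name gaussian_expansion, the choice of the Gaussian width (set equal to grid_spacing), and the numerical values are assumptions chosen for illustration.

```python
import numpy as np


def gaussian_expansion(value, condition_min, condition_max, grid_spacing, width=None):
    """Expand a scalar into a vector of Gaussian basis function values."""
    # number of centers on a regular grid that includes both end points
    n_centers = int(round((condition_max - condition_min) / grid_spacing)) + 1
    centers = np.linspace(condition_min, condition_max, n_centers)
    # Assumption: the width of each Gaussian equals the grid spacing
    width = grid_spacing if width is None else width
    return np.exp(-0.5 * ((value - centers) / width) ** 2)


# Hypothetical settings in eV; these are not the values from the package configs
features = gaussian_expansion(4.0, condition_min=2.0, condition_max=10.0, grid_spacing=2.0)
print(features.shape)  # (5,)
```

With these hypothetical settings the centers lie at 2, 4, 6, 8, and 10 eV, so a target value of 4 eV activates the second basis function most strongly. The resulting vector is what the fully connected layers of the embedding network would receive as input.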

An important argument in the config is required_data_properties. It determines additional properties that need to be loaded from the data base for the conditioning. Here, the relative atomic energy is a special case as it is not directly included in the QM9 data base. Instead, we load the total energy at zero Kelvin energy_U0 and compute the relative atomic energy from this. To this end, there exist transforms that are applied to every data point by the data loader. The transforms are specified in the experiment config as part of the data field:

https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/experiment/gschnet_qm9_gap_relenergy.yaml#L45-L70

The corresponding GetRelativeAtomicEnergy transform is defined in lines 52-55. On a side note, we see that transforms take care of all kinds of preprocessing tasks, e.g. centering the atom positions, computing neighborhood lists of atoms, and sampling a trajectory of atom placements for the molecule.

As another example, you can refer to the experiment gschnet_qm9_comp_relenergy, which targets the atomic composition and the relative atomic energy as conditions. There, the conditioning config uses the same ScalarConditionEmbedding as before for the relative atomic energy but combines it with a CompositionEmbedding for the atomic composition. Accordingly, the transforms in the experiment config contain GetRelativeAtomicEnergy and GetComposition to compute the relative atomic energy and the atomic composition, respectively.

To summarize, we can specify a set of target properties by adding a conditioning config to <path/to/my_gschnet_configs>/model/conditioning. If the target properties can directly be loaded from the data base, we can use the basic gschnet_qm9 experiment and append our new conditioning config in the CLI to start training. For example, suppose we want to only condition our model on the HOMO-LUMO gap. To this end, we can delete lines 14-22 in the gap_relenergy conditioning config shown above and save the resulting file as <path/to/my_gschnet_configs>/model/conditioning/gap.yaml. Then, the training can be started with:

gschnet_train --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9 model/conditioning=gap

If the target properties are not stored in the data base, it is most convenient to set up an experiment config with suitable transforms that compute them. We directly link the conditioning config in the experiment config by overriding /model/conditioning in the defaults list, as can be seen in the second-to-last line of the following example:

https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/experiment/gschnet_qm9_gap_relenergy.yaml#L3-L16

Then, we only provide the name of the new experiment config in the CLI and do not need an additional argument for the conditioning config. For instance, with the experiment config for a model conditioned on HOMO-LUMO gap and relative atomic energy, the call is:

gschnet_train --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9_gap_relenergy

Using custom data

In order to use custom data, we need to store it in the ASE data base format that is used in schnetpack. The preparation of such a data base takes only a few steps and can be done with the help of schnetpack.data.ASEAtomsData. For example, if you have a function read_molecule that gives the atom positions, atom types, and property values of a molecule (e.g. from xyz, cif, or another db file), you can create a data base in the correct format with the following code:

from ase import Atoms
from schnetpack.data import ASEAtomsData
import numpy as np

mol_list = []
property_list = []
for i in range(n_molecules):
    # get molecule information with your custom read_molecule function
    atom_positions, atom_types, property_values = read_molecule(i)
    # create ase.Atoms object and append it to the list
    mol = Atoms(positions=atom_positions, numbers=atom_types)
    mol_list.append(mol)
    # create dictionary that maps property names to property values and append it to the list
    # note that the property values need to be numpy float arrays (even if they are scalar values)
    properties = {
        "energy": np.array([float(property_values[0])]), 
        "gap": np.array([float(property_values[1])]),
    }
    property_list.append(properties)

# create empty data base with correct format
# make sure to provide the correct units of the positions and properties
custom_dataset = ASEAtomsData.create(
    "/home/user/custom_dataset.db",                              # where to store the data base
    distance_unit="Angstrom",                                    # unit of positions
    property_unit_dict={"energy": "Hartree", "gap": "Hartree"},  # units of properties
)
# write gathered molecules and their properties to the data base
custom_dataset.add_systems(property_list, mol_list)

In the for-loop, we build a list of molecular structures in the form of ase.Atoms objects and a corresponding list of dictionaries containing mappings from property names to property values for each molecule. Having these lists, we can easily create an empty data base in the correct format and store our gathered molecules with functions from schnetpack.data.ASEAtomsData. In this example, we assume that we have energy and gap values in Hartree for each molecule and that the atom positions are given in Angstrom. Of course, the read_molecule function and the lines where we specify properties and units need to be adapted carefully to fit your custom data if you use the code from above. To find which units are supported, please check the ASE units module. It includes most common units such as eV, Ha, Bohr, kJ, kcal, mol, Debye, Ang, nm etc.
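
The read_molecule function is left to the user. As a rough sketch of what it could look like for xyz files, consider the following. The file layout is purely hypothetical (one file mol_<i>.xyz per molecule with energy and gap stored in the comment line), as are the xyz_dir argument and the ATOMIC_NUMBERS table, so all of this has to be adapted to your data.

```python
import tempfile
from pathlib import Path

import numpy as np

# nuclear charges of the elements expected in the files (extend as needed)
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}


def read_molecule(i, xyz_dir="xyz_files"):
    """Read positions, atom types, and property values of molecule i.

    Assumed (hypothetical) layout: <xyz_dir>/mol_<i>.xyz, where the comment
    line holds the energy and the gap separated by whitespace.
    """
    lines = Path(xyz_dir, f"mol_{i}.xyz").read_text().splitlines()
    n_atoms = int(lines[0])
    property_values = [float(v) for v in lines[1].split()]
    atom_types, atom_positions = [], []
    for line in lines[2:2 + n_atoms]:
        symbol, x, y, z = line.split()[:4]
        atom_types.append(ATOMIC_NUMBERS[symbol])
        atom_positions.append([float(x), float(y), float(z)])
    return np.array(atom_positions), np.array(atom_types), property_values


# small demo with a made-up water molecule written to a temporary directory
demo_dir = tempfile.mkdtemp()
Path(demo_dir, "mol_0.xyz").write_text(
    "3\n-76.4 0.25\nO 0.0 0.0 0.0\nH 0.0 0.76 0.59\nH 0.0 -0.76 0.59\n"
)
positions, types, props = read_molecule(0, xyz_dir=demo_dir)
print(positions.shape, types, props)
```

The returned values have the form expected by the loop above: an array of positions, an array of nuclear charges, and a list of property values in the order energy, gap.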

Once the data is in the required format, we can train G-SchNet. To this end, we provide the gschnet_template experiment config and the custom_data data config:

https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/experiment/gschnet_template.yaml#L1-L31 https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/data/custom_data.yaml#L1-L28

Here, arguments specific to the custom data set are left with ???, which means that they need to be specified in the CLI when using the configs. For example, assume we have stored a data base with 10k molecules consisting of carbon, oxygen, and hydrogen at /home/user/custom_dataset.db. Then, we can start the training process with the following call:

gschnet_train --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_template data.datapath=/home/user/custom_dataset.db data.batch_size=10 data.num_train=5000 data.num_val=2000 globals.name=custom_data globals.id=first_run globals.model_cutoff=10 globals.prediction_cutoff=5 globals.placement_cutoff=1.7 globals.atom_types="[1, 6, 8]"

Alternatively, you can copy the configs and fill in the left-out arguments in the files. Please choose them according to your data, e.g. the placement_cutoff should be slightly larger than the typical bond lengths. Systems with periodic boundary conditions are currently not supported. By default, a model without target properties is trained. To train a model with conditions, you need to create a conditioning config as explained in the previous section. To convert the units of target properties upon loading, add property_units to the data config (cf. the QM9 data config). Also, note our hints on scaling up the training in the following section if your data set includes large molecules.

Scaling up the training

The QM9 data set, which is used in the example experiments, contains only small organic compounds. Therefore, the default settings in those configs might lead to excessive memory and runtime requirements when using other data sets. In the following, we shed light on the most important settings and tweaks to use G-SchNet with molecules much larger than those in QM9.

1. Set draw_random_samples > 0

In each epoch, the data loader reads each molecule from the data base and samples a trajectory of atom placements, i.e. it starts with the first atom, then selects a second atom, a third atom, and so on. When predicting the position of the second atom, the model uses the structure consisting of only the first atom, when predicting the position of the third atom, the model uses the structure consisting of the two first atoms etc. Therefore, the partial molecule occurs many times in a batch, once for every prediction step. Accordingly, the number of atoms in a batch scales quadratically with the number of atoms in the training structures. For QM9 this is feasible but it becomes problematic when working with larger structures. To remedy this effect, we can restrict the batch to contain only a fixed number of prediction steps per molecule, which confines it to scale linearly. The config setting globals.draw_random_samples determines the number of steps that are drawn randomly from the whole trajectory. For data sets with large molecules, we recommend to set it to a small number, e.g. five. Setting it to zero restores the default behavior, which adds all steps to the batch. Please note that the prediction steps for the validation batch will also be drawn randomly, which causes the validation loss to vary slightly even if the model does not change.
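
The difference between quadratic and linear scaling can be checked with a simple count. The sketch below is a simplification and not the sampling code of the package: the function atoms_per_molecule is hypothetical, and it ignores the auxiliary tokens and stop predictions, counting only the atoms of the partial structures.

```python
import random


def atoms_per_molecule(n_atoms, draw_random_samples=0, seed=0):
    """Count the atoms that one molecule contributes to a batch.

    Simplification: step k of the trajectory contains the k atoms placed so
    far; auxiliary tokens and stop predictions are ignored.
    """
    steps = list(range(1, n_atoms))  # partial structures with 1 .. n_atoms-1 atoms
    if draw_random_samples > 0:
        k = min(draw_random_samples, len(steps))
        steps = random.Random(seed).sample(steps, k)  # drawn without replacement
    return sum(steps)


for n in (10, 100, 1000):
    print(n, atoms_per_molecule(n), atoms_per_molecule(n, draw_random_samples=5))
```

Using all steps, a molecule with n atoms contributes n(n-1)/2 atoms in this simplified count, i.e. 45 for n=10 but 499500 for n=1000. With five randomly drawn steps, the contribution is bounded by 5(n-1).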

2. Choose suitable cutoffs

Two other parameters that influence the scaling of the method are the prediction cutoff and the model cutoff. The network will only predict the distances of the new atom to atoms that are closer to the focus than the prediction cutoff. This limits the number of distributions predicted at each step and therefore improves the scaling compared to previous implementations, where the distances to all preceding atoms were predicted. Choosing a large prediction cutoff will lead to higher flexibility of the network but also higher memory consumption and potentially redundant predictions. A very small prediction cutoff might hurt the performance of the model. The model cutoff determines the neighbors that exchange messages when extracting atom-wise features from the molecular structure. A large model cutoff can be interpreted as increasing the receptive field of an atom but comes with higher computational costs. From our experience, we recommend setting values between 5 and 10 Angstrom, with the model cutoff larger than or equal to the prediction cutoff. The corresponding config settings are globals.prediction_cutoff and globals.model_cutoff. For QM9, we use 10 Angstrom for the model cutoff and 5 Angstrom for the prediction cutoff, which should also be a reasonable starting point for other data sets.
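
As an illustration of how the prediction cutoff limits the number of predicted distance distributions, the following sketch selects the atoms around a focus in a toy structure. The function atoms_within_cutoff is hypothetical and only mimics the selection described above; whether the focus atom itself is part of the selection in the package is an assumption here.

```python
import numpy as np


def atoms_within_cutoff(positions, focus_index, prediction_cutoff=5.0):
    """Return indices of atoms closer to the focus than the prediction cutoff."""
    distances = np.linalg.norm(positions - positions[focus_index], axis=1)
    return np.flatnonzero(distances < prediction_cutoff)


# toy chain of 10 atoms spaced 1.5 Angstrom apart along the x-axis
positions = np.zeros((10, 3))
positions[:, 0] = 1.5 * np.arange(10)
selected = atoms_within_cutoff(positions, focus_index=0)
print(selected)  # [0 1 2 3]
```

Only four of the ten atoms fall inside the cutoff of 5 Angstrom, so distances would be predicted for these four atoms instead of all preceding ones. In extended structures, the number of selected atoms stays roughly constant while the total number of atoms grows.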

3. Use caching of neighborlists

The data loading and preprocessing can become quite expensive for larger structures, especially the computation of the neighborlists of all atoms. The batches are loaded in parallel to neural network computations by separate CPU threads. If the GPU has a low utilization during training, it might help to use more workers to reduce the waiting time of the GPU. The number of workers for training, validation, and test data can be set with data.num_workers, data.num_val_workers, and data.num_test_workers, respectively. However, for larger molecules we generally recommend to cache the computed neighborlists to reduce the load of the workers. To this end, the package contains the GeneralCachedNeighborList transform. It can be incorporated in the experiment config by wrapping ConditionalGSchNetNeighborList in the list of transforms. For custom data, we have a config file called custom_data_cached with the corresponding setup of transforms:

https://github.com/atomistic-machine-learning/schnetpack-gschnet/blob/dd80e3c6ad04791b1f4dc3820926d4db460522b1/src/schnetpack_gschnet/configs/data/custom_data_cached.yaml#L17-L27

To use this data config, add data=custom_data_cached to the training call. The cache will be stored at cache_path and deleted after training unless keep_cache is set to True. As the neighborlist results depend on the chosen cutoffs, do not re-use the cache from previous runs unless you are 100% sure that all settings are identical.

4. Use working directories for data and caching

Another reason for low GPU utilization can be slow reading speed of the data and cache. For example, when running code on a cluster, you often have slow, shared storage and faster, local node storage. Therefore, you can choose working directories for the data and cache by setting globals.data_workdir and globals.cache_workdir to directories on the fast storage. If you do so, the data and cache will be copied to these locations and then read from there for the training run. They are automatically deleted if the run finishes without errors.

Molecule generation

After training a model, you can generate molecules from the CLI with the generation script. Four parameters are required: the root directory of the trained model (i.e. the directory containing the files best_model, cli.log, config.yaml etc.), the number of molecules that shall be generated, the batch size for generation, and the maximum number of atoms that the model is allowed to sample per molecule:

gschnet_generate modeldir=<path/to/trained/model> n_molecules=1000 batch_size=500 max_n_atoms=120

The generated molecules are stored in an ASE data base at <modeldir>/generated_molecules/. For models trained with conditions, target values for all properties that were used have to be specified. For example, for a model trained with the gschnet_qm9_gap_relenergy config, both a target HOMO-LUMO gap and relative atomic energy have to be set. This can be done by appending the following arguments to the CLI call:

++conditions.gap=4.0 ++conditions.relative_atomic_energy=-0.2

Here the ++ is needed to append new arguments to the config (as opposed to setting new values for existing config entries). Note that the names of the target properties have to correspond to the condition_name specified in the conditioning configs of the trained model. That is why we use gap and relative_atomic_energy here, as specified in lines 6 and 15 of the conditioning config. In models conditioned on the atomic composition, the corresponding property name is automatically set to composition.
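
If you launch generation runs from a Python script, e.g. to scan several target values, the call can be assembled programmatically. The helper below is a hypothetical convenience function and not part of the package. It only uses the arguments described above, and the model path is a placeholder.

```python
import shlex


def build_generate_command(modeldir, n_molecules, batch_size, max_n_atoms, conditions=None):
    """Assemble the argument list for a gschnet_generate call."""
    args = [
        "gschnet_generate",
        f"modeldir={modeldir}",
        f"n_molecules={n_molecules}",
        f"batch_size={batch_size}",
        f"max_n_atoms={max_n_atoms}",
    ]
    # target values are appended with the ++ prefix, one per condition_name
    for name, value in (conditions or {}).items():
        args.append(f"++conditions.{name}={value}")
    return args


command = build_generate_command(
    "models/qm9_gap_relenergy/my_run",  # hypothetical path to a trained model
    n_molecules=1000, batch_size=500, max_n_atoms=120,
    conditions={"gap": 4.0, "relative_atomic_energy": -0.2},
)
print(shlex.join(command))
# pass `command` to subprocess.run(command, check=True) to launch the script
```

The keys of the conditions dictionary have to match the condition_name entries of the conditioning config, exactly as in the manual CLI call.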

In the following table, we list all settings for the generation script and their default values (where ??? marks entries that have no default value and thus are required when calling the script). All settings can directly be set in the CLI, e.g. add view_molecules=True to display the molecules after generation.

Name Value Description
modeldir ??? The directory where the trained model is stored (the directory containing the files best_model, cli.log, config.yaml etc.).
n_molecules ??? The number of molecules that shall be generated. Note that the number of molecules in the resulting data base can be lower as failed generation attempts are not stored, i.e. where the model has not finished generation after placing max_n_atoms atoms.
batch_size ??? The number of molecules generated in one batch. Can be significantly larger than training batches. Use large batches if possible and decrease the batch size if your GPU runs out of memory.
max_n_atoms ??? The maximum number of atoms the model is allowed to place. If it has not finished after placing this many atoms, it will discard the structure as a failed generation attempt. Usually, you set this larger than the largest molecules in the training data set.
outputfile null Name of the data base where generated molecules are stored. The data base will always be stored at <path/to/trained/model>/generated_molecules/. If null, the script will automatically assign a number to the data base (it starts to count from 1 and increases the count by one if a data base with the number already exists).
use_gpu True Set True to run generation on the GPU.
view_molecules False Set True to automatically open a pop-up window with visualizations of all generated structures (uses the ASE package for visualization).
grid_distance_min 0.7 The minimum distance between a new atom and the focus atom. Determines the extent of the 3d grid together with the placement_cutoff used during model training, which sets the maximum distance between new atom and focus.
grid_spacing 0.05 The size of a bin in the 3d grid (i.e. a value of 0.05 means each bin has a size of 0.05x0.05x0.05).
temperature_term 0.1 The temperature term in the normalization of the 3d grid probability. A smaller value leads to more pronounced peaks whereas a larger value increases randomness by smoothing the distribution.
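
The roles of grid_distance_min, grid_spacing, and temperature_term can be illustrated with a short NumPy sketch. It reflects one possible reading of the descriptions in the table and is not the code of the package: the functions build_grid and apply_temperature are hypothetical, and the form of the temperature scaling (probabilities raised to the power 1/T and renormalized) is an assumption.

```python
import numpy as np


def build_grid(grid_distance_min=0.7, placement_cutoff=1.7, grid_spacing=0.05):
    """Return grid points whose distance to the focus lies in [min, cutoff]."""
    n = int(round(placement_cutoff / grid_spacing))
    axis = grid_spacing * np.arange(-n, n + 1)
    points = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1).reshape(-1, 3)
    dist = np.linalg.norm(points, axis=1)
    eps = 1e-8  # tolerance for floating point round-off at the boundaries
    mask = (dist >= grid_distance_min - eps) & (dist <= placement_cutoff + eps)
    return points[mask]


def apply_temperature(probabilities, temperature_term=0.1):
    """Sharpen (T < 1) or flatten (T > 1) a discrete distribution."""
    # Assumption: p_T is proportional to p ** (1 / T); the package may differ
    log_p = np.log(probabilities) / temperature_term
    log_p -= log_p.max()  # for numerical stability
    p = np.exp(log_p)
    return p / p.sum()


grid = build_grid()
p = np.array([0.5, 0.3, 0.2])
print(grid.shape)
print(apply_temperature(p, temperature_term=0.1))
print(apply_temperature(p, temperature_term=1.0))
```

In this sketch, a temperature term of 1.0 leaves the distribution unchanged, whereas the default of 0.1 concentrates almost all probability mass on the most likely bin. A smaller grid_spacing yields a finer grid but increases the number of candidate positions cubically.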

Filtering molecules

After generation, a common postprocessing step is to filter the molecules for validity, uniqueness, and novelty. There exist many ways to do so and thus statistics reported in publications are often not directly comparable. In schnetpack-gschnet, we provide a standardized script that filters a data base of molecules using the implementation of xyz2mol available in RDKit. The translation of the structures, and therefore the result of the analysis, depends on the exact version of RDKit and the options used when calling the script. Therefore, if you employ the script for analysis in a publication, please report the installed version of RDKit, the version of schnetpack-gschnet, and the exact options used in the call. This facilitates the reproduction of the results and enables fair comparisons.

The script translates the generated 3d structures into molecular graphs in the form of canonical, isomeric SMILES strings. Then, based on the string, it checks the valency of all atoms given permissible charges of the system. It also uses the strings to check the uniqueness of molecules and to compare them to structures in the training data set. The script is called check_validity.py and a typical call is:

python <path/to/schnetpack-gschnet>/src/scripts/check_validity.py <path-to-db> --compute_uniqueness --compare_db_path <path-to-training-db> --compare_db_split_path <modeldir>/split.npz --ignore_enantiomers --timeout 2 --results_db_path auto --results_db_flags unique

Here, we provide the paths to the training data base and the split file so that the script can identify which generated molecules match training, validation, and test structures. The routine stores a new data base file containing only valid and unique molecules in the same directory as the input data base, using the same name with _filtered appended. The --ignore_enantiomers option is optional; when set, mirror-image stereoisomers are treated as identical molecules. The timeout of 2 seconds interrupts the translation from 3d structure to molecular graph if it takes too long, which can happen for large molecules. The script prints the number of molecules that could not be translated due to the timeout. Please note that the process is not always interrupted immediately and can therefore take longer than the specified timeout per structure. By default, molecules are only allowed to have a total charge of zero and cannot contain charged fragments. Since charged fragments or other total charges might be permissible depending on the training data set, these settings can be adjusted with, e.g., --allow_charged_fragments --allowed_charges 0 -1 to allow charged fragments and total charges of zero and minus one. To display all available options, call the script with --help. Depending on the required analysis, the call can be simplified by removing options. For example, to only identify and remove invalid structures, use:

python <path/to/schnetpack-gschnet>/src/scripts/check_validity.py <path-to-db> --timeout 2 --results_db_path auto

This call can also be used with training data bases, e.g. to compute the percentage of training structures that can be validated with the method for reference or to filter the structures prior to training.
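
The bookkeeping behind validity, uniqueness, and novelty can be sketched in a few lines once every structure has been translated into a canonical SMILES string. The example below is a simplified illustration under stated assumptions, not the logic of check_validity.py: the function name summarize_filtering is hypothetical, failed or invalid translations are assumed to be represented by None, and the valence and charge checks performed by the actual script are left out.

```python
def summarize_filtering(generated_smiles, training_smiles):
    """Count valid, unique, and novel molecules (hypothetical helper).

    generated_smiles: list of canonical SMILES strings, with None marking
        structures that could not be translated or were judged invalid.
    training_smiles: iterable of canonical SMILES strings of training molecules.
    """
    training = set(training_smiles)
    # Validity: keep only structures that have a SMILES string.
    valid = [s for s in generated_smiles if s is not None]
    # Uniqueness: keep the first occurrence of every string, preserving order.
    unique = list(dict.fromkeys(valid))
    # Novelty: keep unique molecules that do not appear in the training data.
    novel = [s for s in unique if s not in training]
    return {
        "n_generated": len(generated_smiles),
        "n_valid": len(valid),
        "n_unique": len(unique),
        "n_novel": len(novel),
    }


stats = summarize_filtering(["CCO", "CCO", None, "C", "CC=O"], ["C"])
print(stats)
```

Because all comparisons are string-based, the counts depend on how the SMILES strings were produced, which is why the RDKit version and the script options should be reported alongside the statistics.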

Additional information

FAQ and troubleshooting

1. Can I restart training after a run crashed or timed out?

Yes, if you start a run with an existing run.id, the training will automatically be resumed from the last stored checkpoint. The run.id is the name of the folder where the model, logs, config etc. of a run are stored. In our example configs, run.id is automatically set to a unique identifier if you do not provide it manually. Accordingly, every time a run is started, a new folder is created and a new model is trained. However, if you set run.id=<your_name> in your CLI call, the training will be resumed if a folder called <your_name> already exists (otherwise, a fresh run with that name will be initialized).

2. Data preprocessing is getting stuck or crashing

In earlier versions, we used Python's built-in multiprocessing package to speed up the data setup phase before training starts. On some machines, this caused the process to get stuck or to crash with error messages related to pickle. We have since switched to a PyTorch dataloader with multiple workers for this task, which is slightly faster and should run on all systems. Therefore, please make sure that you are using schnetpack-gschnet>=1.0.0 if the preprocessing is getting stuck or crashing. If the problem persists, please open an issue. As a short-term fix, multiprocessing can be deactivated for the data setup by appending +data.num_preprocessing_workers=0 to the training call.

Changes in this implementation

Compared to previous implementations of G-SchNet, we improved the scalability and simplified the adaptation to custom data sets. The changes we made mainly concern the preparation of data and the reconstruction of the 3d positional distribution.

Accordingly, in comparison to previous implementations where G-SchNet had only a single model cutoff that determined which atoms are exchanging messages in the SchNet interaction blocks, this version has three cutoffs as hyperparameters, namely the model cutoff, the prediction cutoff, and the placement cutoff.

Citation

If you use G-SchNet in your research, please cite the corresponding publications:

N.W.A. Gebauer, M. Gastegger, S.S.P. Hessmann, K.-R. Müller, and K.T. Schütt. Inverse design of 3d molecular structures with conditional generative neural networks. Nature Communications 13, 973 (2022). https://doi.org/10.1038/s41467-022-28526-y

N. Gebauer, M. Gastegger, and K. Schütt. Symmetry-adapted generation of 3d point sets for the targeted discovery of molecules. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, 7566–7578. Curran Associates, Inc. (2019). http://papers.nips.cc/paper/8974-symmetry-adapted-generation-of-3d-point-sets-for-the-targeted-discovery-of-molecules.pdf

K.T. Schütt, S.S.P. Hessmann, N.W.A. Gebauer, J. Lederer, and M. Gastegger. SchNetPack 2.0: A neural network toolbox for atomistic machine learning. The Journal of Chemical Physics 158, 144801 (2023). https://doi.org/10.1063/5.0138367

@Article{gebauer2022inverse,
    author = {Gebauer, Niklas W. A. and Gastegger, Michael and Hessmann, Stefaan S. P. and M{\"u}ller, Klaus-Robert and Sch{\"u}tt, Kristof T.},
    title = {Inverse design of 3d molecular structures with conditional generative neural networks},
    journal = {Nature Communications},
    year = {2022},
    volume = {13},
    number = {1},
    pages = {973},
    issn = {2041-1723},
    doi = {10.1038/s41467-022-28526-y},
    url = {https://doi.org/10.1038/s41467-022-28526-y},
}
@incollection{gebauer2019symmetry,
    author = {Gebauer, Niklas and Gastegger, Michael and Sch\"{u}tt, Kristof},
    title = {Symmetry-adapted generation of 3d point sets for the targeted discovery of molecules},
    booktitle = {Advances in Neural Information Processing Systems 32},
    editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
    year = {2019},
    pages = {7566--7578},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/8974-symmetry-adapted-generation-of-3d-point-sets-for-the-targeted-discovery-of-molecules.pdf},
}
@article{schutt2023schnetpack,
    author = {Sch{\"u}tt, Kristof T. and Hessmann, Stefaan S. P. and Gebauer, Niklas W. A. and Lederer, Jonas and Gastegger, Michael},
    title = "{SchNetPack 2.0: A neural network toolbox for atomistic machine learning}",
    journal = {The Journal of Chemical Physics},
    volume = {158},
    number = {14},
    pages = {144801},
    year = {2023},
    month = {04},
    issn = {0021-9606},
    doi = {10.1063/5.0138367},
    url = {https://doi.org/10.1063/5.0138367},
    eprint = {https://pubs.aip.org/aip/jcp/article-pdf/doi/10.1063/5.0138367/16825487/144801\_1\_5.0138367.pdf},
}

How does cG-SchNet work?

cG-SchNet is an autoregressive neural network. It builds 3d molecules by placing one atom after another in 3d space. To this end, the joint distribution of all atoms is factorized into single steps, where the position and type of the new atom depend on the preceding atoms (Figure a). The model also processes conditions, i.e. values of target properties, which enable it to learn a conditional distribution of molecular structures. This distribution allows targeted sampling of molecules that are highly likely to exhibit the specified property values (see e.g. the distribution of the polarizability of molecules generated with cG-SchNet using five different target values in Figure b). The type and absolute position of new atoms are sampled successively, where the probability of the positions is approximated from predicted pairwise distances to preceding atoms. In order to improve the accuracy of the approximation and steer the generation process, the network uses two auxiliary tokens, the focus and the origin. The new atom always has to be a neighbor of the focus, and the origin marks the supposed center of mass of the final structure. A scheme explaining the generation procedure can be seen in Figure c. It uses 2d positional distributions for visualization purposes. For more details, please refer to the cG-SchNet publication.
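
The factorization described above can be written down compactly. The notation below is our own shorthand and an assumption for illustration: R and Z denote atom positions and types, \Lambda the conditions, d_{ij} the distance between the new atom i and a preceding atom j, and \alpha a normalization constant. The auxiliary focus and origin tokens are omitted for brevity, and the type is assumed to be sampled before the position.

```latex
% Joint distribution factorized into single placement steps
p(R_{\le n}, Z_{\le n} \mid \Lambda)
  = \prod_{i=1}^{n} p(r_i, Z_i \mid R_{<i}, Z_{<i}, \Lambda)

% Each step: first the type, then the position
p(r_i, Z_i \mid R_{<i}, Z_{<i}, \Lambda)
  = p(Z_i \mid R_{<i}, Z_{<i}, \Lambda)\,
    p(r_i \mid Z_i, R_{<i}, Z_{<i}, \Lambda)

% Position probability approximated from predicted pairwise distances
p(r_i \mid Z_i, R_{<i}, Z_{<i}, \Lambda)
  \approx \frac{1}{\alpha} \prod_{j<i}
    p(d_{ij} \mid Z_i, R_{<i}, Z_{<i}, \Lambda),
  \qquad d_{ij} = \lVert r_i - r_j \rVert
```

Since the last expression only depends on distances, it is unchanged under rotations and translations of the partial structure, which is what allows the position distribution to be reconstructed on a 3d grid around the focus.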

generated molecules