Hi @DDDIGHE,
Yes, exactly. In the downloaded QM9 database, the property alpha is stored under the name isotropic_polarizability. You need to define a conditioning network at configs/model/conditioning/isotropic_polarizability.yaml that takes into account the range of the property in the database (QM9 gives it in Bohr^3). It could, for example, look like this:
_target_: schnetpack_gschnet.ConditioningModule
n_features: 128
n_layers: 5
condition_embeddings:
  - _target_: schnetpack_gschnet.ScalarConditionEmbedding
    condition_name: isotropic_polarizability
    condition_min: 40
    condition_max: 100
    grid_spacing: 15
    n_features: 64
    n_layers: 3
required_data_properties:
  - isotropic_polarizability
This specifies a network that embeds isotropic_polarizability in a vector space using Gaussians centered at 40, 55, 70, 85, and 100 (from condition_min to condition_max in steps of grid_spacing). It then transforms the embedding through several fully connected layers.
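The following minimal NumPy sketch shows how condition_min, condition_max, and grid_spacing produce these five centers and what such a Gaussian embedding looks like for a single value. It is not the schnetpack_gschnet implementation. The helper names gaussian_centers and gaussian_expand are hypothetical, and using the grid spacing as the Gaussian width is an assumption.

```python
import numpy as np


def gaussian_centers(condition_min, condition_max, grid_spacing):
    # Evenly spaced centers from condition_min to condition_max (inclusive),
    # one every grid_spacing units.
    n_centers = int(round((condition_max - condition_min) / grid_spacing)) + 1
    return np.linspace(condition_min, condition_max, n_centers)


def gaussian_expand(value, centers, width):
    # Hypothetical Gaussian expansion: exp(-0.5 * ((value - center) / width)**2).
    # The width used by schnetpack_gschnet may differ; here it is an assumption.
    return np.exp(-0.5 * ((value - centers) / width) ** 2)


centers = gaussian_centers(40, 100, 15)
print(centers.tolist())  # [40.0, 55.0, 70.0, 85.0, 100.0]

# Embed a polarizability of 75 Bohr^3 into a 5-dimensional feature vector.
features = gaussian_expand(75.0, centers, width=15.0)
print(features.round(3))
```

A value of 75 lies between the centers at 70 and 85, so those two features are the largest. Values far outside 40 to 100 only weakly activate the outermost Gaussian, which is why the grid should cover the range of the property in the database.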
Then you need to define an experiment config at configs/experiment/gschnet_qm9_isotropic_polarizability.yaml that uses this conditioning network. For example, we can adapt the config gschnet_qm9:
# @package _global_
defaults:
  - override /data: gschnet_qm9
  - override /task: gschnet_task
  - override /callbacks:
      - checkpoint
      - earlystopping
      - lrmonitor
      - progressbar
      - modelsummary
  - override /model: gschnet
  - override /model/conditioning: isotropic_polarizability
run:
  path: ${run.work_dir}/models/qm9_${globals.name}
  id: ${globals.id}
globals:
  model_cutoff: 10.
  prediction_cutoff: 10.
  placement_cutoff: 1.7
  use_covalent_radii: True
  covalent_radius_factor: 1.1
  atom_types: [1, 6, 7, 8, 9]
  origin_type: 121
  focus_type: 122
  stop_type: 123
  lr: 1e-4
  draw_random_samples: 0
  name: isotropic_polarizability
  id: ${oc.env:SLURM_JOBID,${uuid:1}}_${oc.env:HOSTNAME,""}
  data_workdir: null
  cache_workdir: null
callbacks:
  early_stopping:
    patience: 25
  progress_bar:
    refresh_rate: 100
  model_summary:
    max_depth: -1
data:
  batch_size: 5
  num_train: 50000
  num_val: 5000
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack_gschnet.transform.OrderByDistanceToOrigin
    - _target_: schnetpack_gschnet.transform.ConditionalGSchNetNeighborList
      model_cutoff: ${globals.model_cutoff}
      prediction_cutoff: ${globals.prediction_cutoff}
      placement_cutoff: ${globals.placement_cutoff}
      environment_provider: matscipy
      use_covalent_radii: ${globals.use_covalent_radii}
      covalent_radius_factor: ${globals.covalent_radius_factor}
    - _target_: schnetpack_gschnet.transform.BuildAtomsTrajectory
      centered: True
      origin_type: ${globals.origin_type}
      focus_type: ${globals.focus_type}
      stop_type: ${globals.stop_type}
      draw_random_samples: ${globals.draw_random_samples}
      sort_idx_i: False
    - _target_: schnetpack.transform.CastTo32
Compared with gschnet_qm9, the only changes are the line "- override /model/conditioning: isotropic_polarizability" in the defaults list and "name: isotropic_polarizability" under globals.
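The training command below passes --config-dir. Under Hydra's usual convention, this means both files go into a directory of your own whose subfolders mirror the config groups (model/conditioning/ and experiment/). The following small Python sketch creates that layout. The names write_config and CONDITIONING_YAML are hypothetical, and the demo writes into a temporary directory instead of <path/to/my_gschnet_configs>.

```python
import tempfile
from pathlib import Path

# Copy of the conditioning config from above (hypothetical constant name).
CONDITIONING_YAML = """\
_target_: schnetpack_gschnet.ConditioningModule
n_features: 128
n_layers: 5
condition_embeddings:
  - _target_: schnetpack_gschnet.ScalarConditionEmbedding
    condition_name: isotropic_polarizability
    condition_min: 40
    condition_max: 100
    grid_spacing: 15
    n_features: 64
    n_layers: 3
required_data_properties:
  - isotropic_polarizability
"""


def write_config(config_dir, group, name, content):
    """Write <config_dir>/<group>/<name>.yaml, creating the group folders if needed."""
    path = Path(config_dir) / group / f"{name}.yaml"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content)
    return path


if __name__ == "__main__":
    # Demo in a throwaway directory; use your own config directory for real runs.
    with tempfile.TemporaryDirectory() as config_dir:
        p = write_config(config_dir, "model/conditioning",
                         "isotropic_polarizability", CONDITIONING_YAML)
        print(p.relative_to(config_dir).as_posix())
        # model/conditioning/isotropic_polarizability.yaml
```

Save the experiment config the same way, with group "experiment" and name "gschnet_qm9_isotropic_polarizability".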
Afterwards, you can start training with:
python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9_isotropic_polarizability
Since there are only a few changes, you can also use the gschnet_qm9.yaml experiment config directly and make the changes on the command line:
python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9 model/conditioning=isotropic_polarizability globals.name=isotropic_polarizability
However, if you start changing more settings, it is easier to write a new experiment config.
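If you launch training from Python, you can assemble the command-line variant as an argument list. The following is a hypothetical helper, not part of schnetpack-gschnet. Each dictionary entry becomes one Hydra key=value override, matching the overrides in the command above.

```python
import shlex
from pathlib import Path


def build_train_command(gschnet_root, config_dir, overrides):
    """Hypothetical helper: build argv for train.py from a dict of Hydra overrides."""
    script = Path(gschnet_root) / "src" / "scripts" / "train.py"
    argv = ["python", str(script), f"--config-dir={config_dir}"]
    # Each override becomes a key=value argument, in insertion order.
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


cmd = build_train_command(
    "<path/to/schnetpack-gschnet>",
    "<path/to/my_gschnet_configs>",
    {
        "experiment": "gschnet_qm9",
        "model/conditioning": "isotropic_polarizability",
        "globals.name": "isotropic_polarizability",
    },
)
print(shlex.join(cmd))
```

After replacing the placeholders with real paths, you can pass the list to subprocess.run(cmd, check=True).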
Hope this helps!
Best regards,
Niklas
Thank you, Niklas! You've answered my questions well, and I have successfully started training.
Dear author, I have consulted the relevant documentation, but I am still not clear on how to train a CGSchNet conditioned on alpha. Do I need to write two separate config files, one in schnetpack-gschnet/src/schnetpack_gschnet/configs/model/conditioning/ and one in schnetpack-gschnet/src/schnetpack_gschnet/configs/experiment/?