atomistic-machine-learning / schnetpack-gschnet

G-SchNet extension for SchNetPack
MIT License

Train the model to be conditioned on alpha #11

Closed by DDDIGHE 10 months ago

DDDIGHE commented 10 months ago

Dear author, I have consulted the relevant documentation, but I am still not sure how to train a cG-SchNet model conditioned on alpha. Do I need to write two separate config files, one in schnetpack-gschnet/src/schnetpack_gschnet/configs/model/conditioning/ and one in schnetpack-gschnet/src/schnetpack_gschnet/configs/experiment/?

NiklasGebauer commented 10 months ago

Hi @DDDIGHE ,

yes, exactly. In the downloaded QM9 database, the property alpha is called isotropic_polarizability. You need to define a conditioning network at configs/model/conditioning/isotropic_polarizability.yaml that takes into account the range of the property in the database (QM9 gives it in Bohr³). It could, for example, look like this:

_target_: schnetpack_gschnet.ConditioningModule
n_features: 128
n_layers: 5
condition_embeddings:
  - _target_: schnetpack_gschnet.ScalarConditionEmbedding
    condition_name: isotropic_polarizability
    condition_min: 40
    condition_max: 100
    grid_spacing: 15
    n_features: 64
    n_layers: 3
    required_data_properties:
      - isotropic_polarizability

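This specifies a network that embeds isotropic_polarizability in vector space using Gaussians centered at 40, 55, 70, 85, and 100 (from condition_min to condition_max in steps of grid_spacing) and then transforms this embedding through several fully connected layers (n_layers: 3, n_features: 64 in the embedding; n_layers: 5, n_features: 128 in the surrounding ConditioningModule).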
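To make the Gaussian embedding concrete, here is a small Python sketch that computes the centers from condition_min, condition_max, and grid_spacing and expands one target value over them. It is an illustration, not the schnetpack_gschnet implementation: the helper names are hypothetical, and the Gaussian width is assumed to equal grid_spacing (as with SchNetPack's GaussianRBF on evenly spaced centers).

```python
import math


def gaussian_centers(condition_min, condition_max, grid_spacing):
    """Evenly spaced centers from condition_min to condition_max (hypothetical helper)."""
    # Assumes (condition_max - condition_min) is a multiple of grid_spacing.
    n_centers = int(round((condition_max - condition_min) / grid_spacing)) + 1
    return [condition_min + i * grid_spacing for i in range(n_centers)]


def gaussian_expansion(value, centers, width):
    """One Gaussian activation per center; width is an assumption (= grid_spacing)."""
    return [math.exp(-0.5 * ((value - c) / width) ** 2) for c in centers]


centers = gaussian_centers(condition_min=40, condition_max=100, grid_spacing=15)
print(centers)  # [40, 55, 70, 85, 100]

# Expand a target polarizability of 75 Bohr^3; the centers at 70 and 85 respond most.
features = gaussian_expansion(75.0, centers, width=15)
print([round(f, 3) for f in features])
```

A value far outside [40, 100] (e.g. 150) gives near-zero activations for every center, which is why condition_min and condition_max should cover the range of the property in the training data.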

Then you need to define an experiment config at configs/experiment/gschnet_qm9_isotropic_polarizability.yaml that uses this conditioning network. For example, we can adapt the existing experiment config gschnet_qm9:

# @package _global_

defaults:
  - override /data: gschnet_qm9
  - override /task: gschnet_task
  - override /callbacks:
      - checkpoint
      - earlystopping
      - lrmonitor
      - progressbar
      - modelsummary
  - override /model: gschnet
  - override /model/conditioning: isotropic_polarizability

run:
  path: ${run.work_dir}/models/qm9_${globals.name}
  id: ${globals.id}

globals:
  model_cutoff: 10.
  prediction_cutoff: 10.
  placement_cutoff: 1.7
  use_covalent_radii: True
  covalent_radius_factor: 1.1
  atom_types: [1, 6, 7, 8, 9]
  origin_type: 121
  focus_type: 122
  stop_type: 123
  lr: 1e-4
  draw_random_samples: 0
  name: isotropic_polarizability
  id: ${oc.env:SLURM_JOBID,${uuid:1}}_${oc.env:HOSTNAME,""}
  data_workdir: null
  cache_workdir: null

callbacks:
  early_stopping:
    patience: 25
  progress_bar:
    refresh_rate: 100
  model_summary:
    max_depth: -1

data:
  batch_size: 5
  num_train: 50000
  num_val: 5000
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack_gschnet.transform.OrderByDistanceToOrigin
    - _target_: schnetpack_gschnet.transform.ConditionalGSchNetNeighborList
      model_cutoff: ${globals.model_cutoff}
      prediction_cutoff: ${globals.prediction_cutoff}
      placement_cutoff: ${globals.placement_cutoff}
      environment_provider: matscipy
      use_covalent_radii: ${globals.use_covalent_radii}
      covalent_radius_factor: ${globals.covalent_radius_factor}
    - _target_: schnetpack_gschnet.transform.BuildAtomsTrajectory
      centered: True
      origin_type: ${globals.origin_type}
      focus_type: ${globals.focus_type}
      stop_type: ${globals.stop_type}
      draw_random_samples: ${globals.draw_random_samples}
      sort_idx_i: False
    - _target_: schnetpack.transform.CastTo32

Compared with gschnet_qm9, the only changes are the line "- override /model/conditioning: isotropic_polarizability" in the defaults list and "name: isotropic_polarizability" under globals.

Afterwards, you can start training with:

python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9_isotropic_polarizability

Since there are only a few changes, you can also use the gschnet_qm9.yaml experiment config directly and make the changes on the command line:

python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9 model/conditioning=isotropic_polarizability globals.name=isotropic_polarizability

However, if you start changing more settings, it's easier to write a dedicated experiment config.
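The command-line variant mixes two kinds of Hydra overrides: model/conditioning=isotropic_polarizability selects an option from a config group (the YAML file in configs/model/conditioning/), while globals.name=isotropic_polarizability sets a single value. The Python toy model below illustrates this difference on plain nested dicts. It is a simplified sketch, not Hydra's actual implementation; apply_overrides, base, and groups are hypothetical names with placeholder contents.

```python
import copy


def _parent(node, parts):
    # Walk (creating dicts if needed) down to the parent of the last key.
    for part in parts[:-1]:
        node = node.setdefault(part, {})
    return node


def apply_overrides(config, overrides, config_groups):
    """Toy model of Hydra overrides: 'group/sub=option' swaps in a whole
    config from config_groups, 'section.key=value' sets one value."""
    cfg = copy.deepcopy(config)  # leave the input config untouched
    for override in overrides:
        key, value = override.split("=", 1)
        if "/" in key:
            parts = key.split("/")
            _parent(cfg, parts)[parts[-1]] = copy.deepcopy(config_groups[key][value])
        else:
            parts = key.split(".")
            _parent(cfg, parts)[parts[-1]] = value
    return cfg


# Placeholder stand-ins for gschnet_qm9.yaml and the conditioning config group.
base = {"globals": {"name": "placeholder"}, "model": {"conditioning": None}}
groups = {
    "model/conditioning": {
        "isotropic_polarizability": {
            "_target_": "schnetpack_gschnet.ConditioningModule",
            "n_features": 128,
        }
    }
}

cfg = apply_overrides(
    base,
    ["model/conditioning=isotropic_polarizability",
     "globals.name=isotropic_polarizability"],
    groups,
)
print(cfg["globals"]["name"])  # isotropic_polarizability
print(cfg["model"]["conditioning"]["n_features"])  # 128
```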

Hope this helps!

Best regards,
Niklas

DDDIGHE commented 10 months ago

Thank you, Niklas! You've answered my questions well, and I have successfully started training.