atomistic-machine-learning / schnetpack-gschnet

G-SchNet extension for SchNetPack
MIT License

Train the model to be conditioned on alpha #11

Closed: DDDIGHE closed this issue 8 months ago

DDDIGHE commented 8 months ago

Dear author, I have read the relevant documentation, but I am still not sure how to train a cG-SchNet model conditioned on alpha. Do I need to write two separate config files, one in schnetpack-gschnet/src/schnetpack_gschnet/configs/model/conditioning/ and one in schnetpack-gschnet/src/schnetpack_gschnet/configs/experiment/?

NiklasGebauer commented 8 months ago

Hi @DDDIGHE ,

yes, exactly. The name of the property alpha in the downloaded QM9 database is isotropic_polarizability. You need to define a conditioning network at configs/model/conditioning/isotropic_polarizability.yaml, where you take into account the range of the property in the database (which is given in Bohr^3 in QM9). It could for example look like this:

```yaml
_target_: schnetpack_gschnet.ConditioningModule
n_features: 128
n_layers: 5
condition_embeddings:
  - _target_: schnetpack_gschnet.ScalarConditionEmbedding
    condition_name: isotropic_polarizability
    condition_min: 40
    condition_max: 100
    grid_spacing: 15
    n_features: 64
    n_layers: 3
    required_data_properties:
      - isotropic_polarizability
```

This specifies a network that uses Gaussians centered at 40, 55, 70, 85, 100 to embed the property isotropic_polarizability in vector space and then transform it through several fully connected layers.
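To make the embedding step concrete, here is a minimal plain-Python sketch of the Gaussian expansion described above. It assumes that the number of centers is (condition_max - condition_min) / grid_spacing + 1 and that every Gaussian uses a width equal to grid_spacing; the width is an assumption about the implementation, not something stated in this thread. The helper names are hypothetical and not part of schnetpack_gschnet. The last helper reports which fraction of property values falls inside [condition_min, condition_max], one simple way to check the chosen range against the data.

```python
import math


def gaussian_centers(condition_min, condition_max, grid_spacing):
    """Evenly spaced centers from condition_min to condition_max (inclusive)."""
    n_centers = int(round((condition_max - condition_min) / grid_spacing)) + 1
    return [condition_min + i * grid_spacing for i in range(n_centers)]


def gaussian_expand(value, centers, width):
    """Embed a scalar as activations exp(-0.5 * ((x - mu) / width)^2).

    Assumption: one shared width for all Gaussians (here the grid spacing).
    """
    return [math.exp(-0.5 * ((value - mu) / width) ** 2) for mu in centers]


def fraction_in_range(values, condition_min, condition_max):
    """Fraction of property values covered by [condition_min, condition_max]."""
    values = list(values)
    inside = sum(condition_min <= v <= condition_max for v in values)
    return inside / len(values)


centers = gaussian_centers(40, 100, 15)
print(centers)  # [40, 55, 70, 85, 100]

features = gaussian_expand(70.0, centers, width=15)
print([round(f, 4) for f in features])  # [0.1353, 0.6065, 1.0, 0.6065, 0.1353]

# Hypothetical polarizability values in Bohr^3, for illustration only.
print(fraction_in_range([35.0, 62.5, 75.1, 101.3], 40, 100))  # 0.5
```

In the actual module, the resulting activations are then passed through fully connected layers (n_layers: 3 with n_features: 64 in the config above); that part is omitted from this sketch.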

Then you need to define an experiment config at configs/experiment/gschnet_qm9_isotropic_polarizability.yaml that uses this conditioning network. We can for example adapt the config gschnet_qm9:

```yaml
# @package _global_

defaults:
  - override /data: gschnet_qm9
  - override /task: gschnet_task
  - override /callbacks:
      - checkpoint
      - earlystopping
      - lrmonitor
      - progressbar
      - modelsummary
  - override /model: gschnet
  - override /model/conditioning: isotropic_polarizability

run:
  path: ${run.work_dir}/models/qm9_${globals.name}
  id: ${globals.id}

globals:
  model_cutoff: 10.
  prediction_cutoff: 10.
  placement_cutoff: 1.7
  use_covalent_radii: True
  covalent_radius_factor: 1.1
  atom_types: [1, 6, 7, 8, 9]
  origin_type: 121
  focus_type: 122
  stop_type: 123
  lr: 1e-4
  draw_random_samples: 0
  name: isotropic_polarizability
  id: ${oc.env:SLURM_JOBID,${uuid:1}}_${oc.env:HOSTNAME,""}
  data_workdir: null
  cache_workdir: null

callbacks:
  early_stopping:
    patience: 25
  progress_bar:
    refresh_rate: 100
  model_summary:
    max_depth: -1

data:
  batch_size: 5
  num_train: 50000
  num_val: 5000
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack_gschnet.transform.OrderByDistanceToOrigin
    - _target_: schnetpack_gschnet.transform.ConditionalGSchNetNeighborList
      model_cutoff: ${globals.model_cutoff}
      prediction_cutoff: ${globals.prediction_cutoff}
      placement_cutoff: ${globals.placement_cutoff}
      environment_provider: matscipy
      use_covalent_radii: ${globals.use_covalent_radii}
      covalent_radius_factor: ${globals.covalent_radius_factor}
    - _target_: schnetpack_gschnet.transform.BuildAtomsTrajectory
      centered: True
      origin_type: ${globals.origin_type}
      focus_type: ${globals.focus_type}
      stop_type: ${globals.stop_type}
      draw_random_samples: ${globals.draw_random_samples}
      sort_idx_i: False
    - _target_: schnetpack.transform.CastTo32
```

Compared to gschnet_qm9, the only changes here are the entry `- override /model/conditioning: isotropic_polarizability` in the defaults list and `name: isotropic_polarizability` under globals.

Afterwards, you can start training with:

```bash
python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9_isotropic_polarizability
```

Since there are only a few changes, you can also use the gschnet_qm9.yaml experiment config directly and make the changes on the command line:

```bash
python <path/to/schnetpack-gschnet>/src/scripts/train.py --config-dir=<path/to/my_gschnet_configs> experiment=gschnet_qm9 model/conditioning=isotropic_polarizability globals.name=isotropic_polarizability
```

However, once you change more settings, it's easier to write a dedicated experiment config.
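If you want to script such runs (for example, to launch the same experiment with different overrides), the command-line variant can be assembled programmatically. This is a hypothetical helper, not part of schnetpack-gschnet: it only builds the argument list of the train.py call shown above from key=value overrides, and the paths are placeholders you would replace with your own.

```python
import shlex
from pathlib import Path


def build_train_command(gschnet_root, config_dir, experiment, overrides=None):
    """Assemble the train.py call with Hydra-style key=value overrides."""
    script = Path(gschnet_root) / "src" / "scripts" / "train.py"
    cmd = [
        "python",
        str(script),
        f"--config-dir={config_dir}",
        f"experiment={experiment}",
    ]
    # Each override is appended as key=value, in insertion order.
    for key, value in (overrides or {}).items():
        cmd.append(f"{key}={value}")
    return cmd


cmd = build_train_command(
    "/path/to/schnetpack-gschnet",  # placeholder path
    "/path/to/my_gschnet_configs",  # placeholder path
    "gschnet_qm9",
    {
        "model/conditioning": "isotropic_polarizability",
        "globals.name": "isotropic_polarizability",
    },
)
print(shlex.join(cmd))

# To actually start training (not executed here):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing a list rather than a single string to subprocess avoids shell quoting issues with the `${...}`-style values that Hydra configs sometimes contain.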

Hope this helps!

Best regards,
Niklas

DDDIGHE commented 8 months ago

Thank you, Niklas! You've answered my questions well, and I have successfully started training.