QuantumLab-ZY / HamGNN

An E(3) equivariant Graph Neural Network for predicting electronic Hamiltonian matrix
GNU General Public License v3.0

model can not save to ckpt #1

Open Youhaojen opened 1 year ago

Youhaojen commented 1 year ago

Thanks for your great work. I tried to use the training data (silicon) from your paper to train a model. However, it does not seem to save a .ckpt file. The error reads: The network: HamGNN_pre is not yet supported!

The config.yaml is as follows.

dataset_params:
  batch_size: 1
  split_file: null
  test_ratio: 0.1
  train_ratio: 0.8
  val_ratio: 0.1
  graph_data_path: ./graph_data_Si/ # Directory where graph_data.npz is located

losses_metrics:
  losses:
  - loss_weight: 1.0
    metric: mae
    prediction: hamiltonian
    target: hamiltonian
  metrics:
  - metric: mae
    prediction: hamiltonian
    target: hamiltonian

# Generally, the optim_params module only needs to set the initial learning rate (lr)
optim_params:
  lr: 0.001
  lr_decay: 0.5
  lr_patience: 5
  gradient_clip_val: 0.0
  max_epochs: 6
  min_epochs: 1
  stop_patience: 3

output_nets:
  output_module: HamGNN_out
  HamGNN_out:
    ham_only: true # true: Only the Hamiltonian H is computed; 'false': Fit both H and S
    ham_type: openmx # openmx: fit openmx Hamiltonian; abacus: fit abacus Hamiltonian
    nao_max: 14 # The maximum number of atomic orbitals in the data set, which can be 14, 19 or 27
    add_H0: true # Generally true, the complete Hamiltonian is predicted as the sum of H_scf plus H_nonscf (H0)
    symmetrize: true # if set to true, the Hermitian symmetry constraint is imposed on the Hamiltonian
    calculate_band_energy: false # Whether to calculate the energy bands to train the model
    num_k: 5 # When calculating the energy bands, the number of K points to use
    band_num_control: null # `dict`: controls how many orbitals are considered for each atom in energy bands; `int`: [vbm-num, vbm+num]; `null`: all bands
    k_path: null # `auto`: Automatically determine the k-point path; `null`: random k-point path; `list`: list of k-point paths provided by the user
    soc_switch: false # if true, fit the SOC Hamiltonian
    nonlinearity_type: norm # norm or gate

profiler_params:
  progress_bar_refresh_rat: 1
  train_dir: ./graph_data_Si #The folder for saving training information and prediction results. This directory can be read by tensorboard to monitor the training process.

representation_nets:
  # Network parameters usually do not need to be changed.
  HamGNN_pre:
    cutoff: 20.0
    resnet: True
    cutoff_func: cos
    edge_sh_normalization: component
    edge_sh_normalize: true
    ######## Irreps set 1 (crystal): ################
    feature_irreps_hidden: 32x0o+32x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e
    irreps_edge_output: 32x0o+32x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e
    irreps_edge_sh: 0e + 1o + 2e + 3o + 4e
    irreps_node_features: 32x0o+32x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e
    irreps_node_output: 32x0o+32x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e
    irreps_triplet_output: 32x0o+32x0e+32x1o+32x1e+32x2e+32x2o+32x3o+32x3e+32x4o+32x4e
    invariant_layers: 2
    invariant_neurons: 64
    num_interaction_layers: 5
    num_radial: 8
    num_spherical: 8
    num_types: 60
    export_triplet: false
    rbf_func: bessel
    set_features: true
    add_edge_tp: false
    irreps_node_prev: 16x0o+16x0e+8x1o+8x1e+8x2e+8x2o+8x3o+8x3e+8x4o+8x4e
    num_node_attr_feas: 64

setup:
  GNN_Net: HamGNN_pre
  accelerator: null
  ignore_warnings: false
  #checkpoint_path: ./network_weights_silicon.ckpt # Path to the model weights file
  load_from_checkpoint: false
  resume: false
  num_gpus: 1 # null: use cpu; [i]: use the ith GPU device
  precision: 32
  property: hamiltonian
  stage: fit # fit: training; test: inference

QuantumLab-ZY commented 1 year ago

Thank you for your feedback. The error message 'The network: HamGNN_pre is not yet supported!' is reported when HamGNN finishes training and attempts to log a certain hyperparameter to TensorBoard for visualization. It affects neither the training results of the model nor the saving of .ckpt model files. I have been ignoring this small bug because it has no impact on network training, but I will fix it soon. Your training has probably completed successfully. You can check whether there is a network weight file with the extension .ckpt in the folder train_dir/version_*/checkpoints.
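The check described above can be scripted. The following is a minimal sketch, not part of HamGNN: the helper name `find_checkpoints` is hypothetical, and it assumes only the directory layout train_dir/version_*/checkpoints mentioned in this reply, with train_dir set to ./graph_data_Si as in the config.yaml from the question.

```python
from pathlib import Path


def find_checkpoints(train_dir):
    """List .ckpt files under train_dir/version_*/checkpoints, sorted by path.

    Hypothetical helper: it assumes only the folder layout described in
    this thread and does not call any HamGNN code.
    """
    root = Path(train_dir)
    return sorted(root.glob("version_*/checkpoints/*.ckpt"))


if __name__ == "__main__":
    # train_dir as set under profiler_params in the config.yaml from the question.
    checkpoints = find_checkpoints("./graph_data_Si")
    if checkpoints:
        for ckpt in checkpoints:
            print(ckpt)
    else:
        print("No .ckpt files found; training may not have finished.")
```

If the list is non-empty, the weights were saved despite the error message.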

Youhaojen commented 1 year ago

Thank you for your prompt reply. Yes, there is a .ckpt file in train_dir/version_*/checkpoints. Thanks!