torchmd / torchmd-net

Training neural network potentials
MIT License

How to Simulate a Large System Using TensorNet + OpenMM-Torch? #347

Open kei0822kei opened 2 weeks ago

kei0822kei commented 2 weeks ago

Hi,

Thank you for maintaining a great package. I want to simulate a relatively large system (~10,000 atoms) using TensorNet.

After I finished training a model with TensorNet-SPICE.yaml, I tried to apply it to an MD simulation of a larger system using openmm-torch. When I simulated a system of ~4,000 atoms, the 80 GiB of GPU memory filled up. I found that the force calculation (the backpropagation phase) consumed most of the GPU memory and caused the out-of-memory error.

Is there a way to avoid this?

I expect that the calculation of atomic energies with the `representation_model` (TensorNet) could be split into batches, so that such large GPU memory usage can be avoided. Is that possible?

guillemsimeon commented 2 weeks ago

Hi,

Thanks for using TensorNet. I suspect the problem comes from using a cutoff of 10 Å. The model corresponding to the SPICE yaml was never used to run MD on such large systems. The cutoff was set to 10 Å because of the presence of dimers. However, I think people in the literature have trained other models on SPICE with smaller cutoffs without problems. I would try 5 Å. Also, make sure that the static_shapes argument is False.
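A back-of-the-envelope sketch of why the cutoff matters so much: at a fixed number density, the number of neighbors per atom grows with the cube of the cutoff, so going from 10 Å to 5 Å shrinks the neighbor list (and the per-pair work kept around for the force backward pass) by roughly 8×. The density of 0.1 atoms/Å³ (roughly liquid water) and the helper names are assumptions for illustration, not part of torchmd-net.

```python
import math


def expected_neighbors(cutoff, density=0.1):
    """Mean number of atoms within `cutoff` (Å) of a given atom, assuming a
    uniform number density `density` (atoms/Å^3; 0.1 is roughly liquid water)."""
    return density * (4.0 / 3.0) * math.pi * cutoff ** 3


def edge_count(n_atoms, cutoff, density=0.1):
    # A directed neighbor list has about n_atoms * neighbors-per-atom entries.
    return n_atoms * expected_neighbors(cutoff, density)


for rc in (5.0, 10.0):
    print(f"cutoff {rc:4.1f} Å: ~{expected_neighbors(rc):.0f} neighbors/atom, "
          f"~{edge_count(4000, rc):.2e} pairs for 4000 atoms")
```

Because the scaling is cubic, halving the cutoff matters far more than most other single hyperparameter changes for memory.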

As a side note, a smaller and shallower model compared to the one you get with the yaml might be enough for your purposes. Take into account also that the hyperparameters for SPICE have not been fully optimized. It depends on the accuracy/efficiency tradeoff and on your system.

Also, you can try to initialize a model (create_model) with your desired config, even if it is not trained, and try to deploy it in openmm to perform a simulation step and see if it fits in memory, before training it.
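One way to act on this before committing to a full run, sketched under the assumption that peak memory grows roughly linearly with the number of atoms (true when the density and cutoff are fixed, so neighbors per atom stay constant): time a single step on a few small systems, record the peak GPU memory (for example with `torch.cuda.max_memory_allocated()`), and extrapolate to the target size. The measurements below are made-up placeholders, and the helper functions are hypothetical.

```python
def fit_line(points):
    """Least-squares fit of peak_gib = a + b * n_atoms over (n_atoms, peak_gib) pairs."""
    n = len(points)
    mean_x = sum(p[0] for p in points) / n
    mean_y = sum(p[1] for p in points) / n
    sxx = sum((p[0] - mean_x) ** 2 for p in points)
    sxy = sum((p[0] - mean_x) * (p[1] - mean_y) for p in points)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return intercept, slope


def extrapolate_peak_memory(points, target_atoms):
    # Predict peak memory (GiB) for a larger system from small test runs.
    intercept, slope = fit_line(points)
    return intercept + slope * target_atoms


# Hypothetical (n_atoms, peak GiB) measurements from one-step test runs
measurements = [(500, 6.0), (1000, 10.0), (2000, 18.0)]
print(f"~{extrapolate_peak_memory(measurements, 10000):.0f} GiB for 10,000 atoms")
```

If the prediction exceeds the GPU, shrink the model or cutoff before training rather than after.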

Guillem


kei0822kei commented 2 weeks ago

Thank you for your reply, and sorry for my insufficient description.

Actually, I revised the yaml file from the original one and used cutoff_upper=5.0. My hparams.yaml is as follows. (I wrote an ASE dataset class in order to use the newest SPICE dataset. Please ignore those settings.)

load_model: null
conf: null
num_epochs: 100000
batch_size: 64
inference_batch_size: 64
lr: 0.0001
lr_patience: 15
lr_metric: val
lr_min: 1.0e-07
lr_factor: 0.8
lr_warmup_steps: 1000
early_stopping_patience: 30
reset_trainer: false
weight_decay: 0.0
ema_alpha_y: 1.0
ema_alpha_neg_dy: 1.0
ngpus: -1
num_nodes: 1
precision: 32
log_dir: .
splits: null
train_size: null
val_size: 0.05
test_size: 0.1
test_interval: 10
save_interval: 10
seed: 42
num_workers: 48
redirect: true
gradient_clipping: 40
remove_ref_energy: false
dataset: ASE
dataset_root: /data/spice/2.0.1/spice_with_charge
dataset_arg:
  periodic: false
  energy_key: formation_energy
  forces_key: forces
  partial_charges_key: charges
coord_files: null
embed_files: null
energy_files: null
force_files: null
dataset_preload_limit: 1024
y_weight: 0.05
neg_dy_weight: 0.95
train_loss: mse_loss
train_loss_arg: null
model: tensornet
output_model: Scalar
output_mlp_num_layers: 0
prior_model:
  ZBL:
    cutoff_distance: 3
    max_num_neighbors: 5
charge: false
spin: false
embedding_dimension: 256
num_layers: 6
num_rbf: 64
activation: silu
rbf_type: expnorm
trainable_rbf: false
neighbor_embedding: false
aggr: add
distance_influence: both
attn_activation: silu
num_heads: 8
vector_cutoff: false
equivariance_invariance_group: O(3)
box_vecs: null
static_shapes: false
check_errors: true
derivative: true
cutoff_lower: 0.0
cutoff_upper: 5.0
atom_filter: -1
max_z: 100
max_num_neighbors: 64
standardize: false
reduce_op: add
wandb_use: false
wandb_name: training
wandb_project: training_
wandb_resume_from_id: null
tensorboard_use: true
prior_args:
- cutoff_distance: 3
  max_num_neighbors: 5
  atomic_number:
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
  - 9
  - 10
  - 11
  - 12
  - 13
  - 14
  - 15
  - 16
  - 17
  - 18
  - 19
  - 20
  - 21
  - 22
  - 23
  - 24
  - 25
  - 26
  - 27
  - 28
  - 29
  - 30
  - 31
  - 32
  - 33
  - 34
  - 35
  - 36
  - 37
  - 38
  - 39
  - 40
  - 41
  - 42
  - 43
  - 44
  - 45
  - 46
  - 47
  - 48
  - 49
  - 50
  - 51
  - 52
  - 53
  - 54
  - 55
  - 56
  - 57
  - 58
  - 59
  - 60
  - 61
  - 62
  - 63
  - 64
  - 65
  - 66
  - 67
  - 68
  - 69
  - 70
  - 71
  - 72
  - 73
  - 74
  - 75
  - 76
  - 77
  - 78
  - 79
  - 80
  - 81
  - 82
  - 83
  - 84
  - 85
  - 86
  - 87
  - 88
  - 89
  - 90
  - 91
  - 92
  - 93
  - 94
  - 95
  - 96
  - 97
  - 98
  - 99
  distance_scale: 1.0e-10
  energy_scale: 1.60218e-19

As you advised, I should check the accuracy/efficiency tradeoff, especially for the following settings.

embedding_dimension: 256
num_layers: 6
num_rbf: 64

If I still run out of GPU memory after tuning these parameters, I am going to try splitting the atomic-energy calculation into batches.

If we do not take long-range interactions into account, the system energy can be written as a sum of atomic energies, $E = \sum_i E_i$, where each $E_i$ can be calculated from the information of its neighboring atoms, and the forces $\boldsymbol{F}_i$ can then be calculated as well.
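Written out for this decomposition (a standard chain-rule identity, not anything specific to TensorNet), the force on atom $k$ collects a contribution from every atomic energy that depends on its position, not only from $E_k$:

$$
\boldsymbol{F}_k = -\frac{\partial E}{\partial \boldsymbol{r}_k} = -\sum_i \frac{\partial E_i}{\partial \boldsymbol{r}_k}
$$

Here the sum runs over every atom $i$ whose $E_i$ "sees" atom $k$, so in general $\boldsymbol{F}_k \neq -\partial E_k / \partial \boldsymbol{r}_k$.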

Therefore, I expect I can split the atomic-energy calculation into batches and thereby control GPU memory usage.
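A 1D toy sketch of what such batching would have to do, assuming a hypothetical pair-like atomic energy $E_i = \tfrac{1}{2}\sum_{j \ne i} \phi(r_{ij})$ with a single cutoff (i.e. one "layer"); this is not TensorNet. Batches run over centre atoms $i$, but each $E_i$ still reads every neighbor $j$ (including atoms outside the batch), and the gradients are scatter-added into every atom they touch. The result matches a finite-difference force and does not depend on the batch size.

```python
import random

RC = 2.0  # toy cutoff (arbitrary units)


def phi(r):
    # Smooth toy pair term that goes to zero (with zero slope) at the cutoff.
    return (RC - r) ** 2 if r < RC else 0.0


def dphi(r):
    return -2.0 * (RC - r) if r < RC else 0.0


def atomic_energy(i, x):
    # E_i depends on every atom within RC of atom i, not only atoms in i's batch.
    return 0.5 * sum(phi(abs(x[i] - x[j])) for j in range(len(x)) if j != i)


def total_energy(x):
    return sum(atomic_energy(i, x) for i in range(len(x)))


def batched_forces(x, batch_size):
    """F_k = -dE/dx_k, looping over batches of centre atoms i and
    scatter-adding -dE_i/dx_k into every atom k that E_i depends on."""
    n = len(x)
    forces = [0.0] * n
    for start in range(0, n, batch_size):
        for i in range(start, min(start + batch_size, n)):
            for j in range(n):
                d = x[i] - x[j]
                r = abs(d)
                if j == i or r == 0.0 or r >= RC:
                    continue
                g = 0.5 * dphi(r) * (d / r)  # dE_i/dx_i from the pair (i, j)
                forces[i] -= g
                forces[j] += g  # dE_i/dx_j = -g, so F_j gains +g
    return forces


def numerical_forces(x, h=1e-6):
    # Central finite differences of the total energy, for checking.
    out = []
    for k in range(len(x)):
        xp, xm = list(x), list(x)
        xp[k] += h
        xm[k] -= h
        out.append(-(total_energy(xp) - total_energy(xm)) / (2 * h))
    return out


random.seed(0)
xs = [random.uniform(0.0, 10.0) for _ in range(20)]
err = max(abs(a - b) for a, b in zip(batched_forces(xs, 4), numerical_forces(xs)))
print(f"max |batched - finite difference| = {err:.2e}")
```

Batching over centre atoms works, but every batch still has to load its neighbors and every force needs contributions from several batches; with more message-passing layers that neighborhood grows.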

Do you think this is possible?

guillemsimeon commented 2 weeks ago

6 layers is crazy. 256 hidden dimensions is too much. It is expected that it does not fit. The largest model I have ever used had 256 hidden dimensions and 3 layers, to achieve SOTA on QM9; in all other cases I used 128 hidden dimensions and 2 layers (1 layer is also good). I don't expect the model to work much better with 6 layers than with 2 (in fact, I would expect the opposite). 6 layers has never been tested, and the simulation/training will be extremely slow.
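A crude way to put numbers on this, under an explicitly assumed proxy: activations kept for the force backward pass grow linearly with the hidden width and with the number of blocks (one embedding block plus `num_layers` interaction layers). This ignores TensorNet's internal details and the neighbor count, so read the ratio as an order of magnitude only; the helper is hypothetical.

```python
def relative_activation_cost(hidden, num_layers):
    """Assumed proxy: hidden width times (embedding block + interaction layers)."""
    return hidden * (num_layers + 1)


big = relative_activation_cost(256, 6)    # configuration from the hparams.yaml above
small = relative_activation_cost(128, 2)  # the size recommended in this reply
print(f"256 hidden / 6 layers needs ~{big / small:.1f}x the activations of 128 / 2")
```

Even under this linear proxy the larger configuration needs several times the memory, on top of whatever the cutoff contributes.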

Regarding the splitting, I am not sure I understood. But take into account that for energy and force prediction, the output corresponding to an atom depends on all other atoms found within a distance of (L+1)*cutoff, where L is the number of layers you put in the yaml file. So to get a meaningful atomic energy contribution you need that atom and all the other ones. Also, take into account that the force on an atom is the negative gradient of the total energy with respect to its position, so you would need to accumulate and sum the gradients of all the atomic energies $U_i$ that depend on that atom. It is possible, but not useful.
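To see why batching is "possible, but not useful", here is a small sketch of how many atoms each batch would need to load, using the (L+1)*cutoff receptive field stated above with a 5 Å cutoff. The uniform density of 0.1 atoms/Å³ (roughly liquid water) is an assumption, and the helpers are hypothetical.

```python
import math


def receptive_radius(num_layers, cutoff):
    # Radius from the reply above: (L + 1) * cutoff
    return (num_layers + 1) * cutoff


def atoms_in_sphere(radius, density=0.1):
    # Assumed uniform number density (atoms/Å^3)
    return density * (4.0 / 3.0) * math.pi * radius ** 3


for layers in (1, 2, 6):
    r = receptive_radius(layers, 5.0)
    print(f"L={layers}: radius {r:4.1f} Å, ~{atoms_in_sphere(r):.0f} atoms per atomic energy")
```

With 6 layers the neighborhood of a single atom (~35 Å radius) already holds more atoms than the whole 10,000-atom target system, so each batch would effectively reload everything.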


kei0822kei commented 2 weeks ago

I understand now that my model is far too large. I should have read the paper more carefully. Thank you so much!