mir-group / nequip

NequIP is a code for building E(3)-equivariant interatomic potentials
https://www.nature.com/articles/s41467-022-29939-5
MIT License

🐛 [BUG] Batchsize Problem in PerAtomMSELoss #332

Closed: QuantumMisaka closed this issue 1 year ago

QuantumMisaka commented 1 year ago

Describe the bug: When I use NequIP for training and validation on a Fe-C-H-O dataset (12135 training frames and 1000 test frames) with a batch size of 5 for both training and validation, the following error occurs:

Traceback (most recent call last):
  File "/home/liuzq/apps/anaconda3/envs/nequip2/bin/nequip-train", line 8, in <module>
    sys.exit(main())
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/scripts/train.py", line 78, in main
    trainer.train()
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/train/trainer.py", line 778, in train
    self.epoch_step()
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/train/trainer.py", line 916, in epoch_step
    self.batch_step(
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/train/trainer.py", line 855, in batch_step
    loss, loss_contrib = self.loss(pred=scaled_out, ref=_data_unscaled)
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/train/loss.py", line 104, in __call__
    _loss = self.funcs[key](
  File "/home/liuzq/apps/anaconda3/envs/nequip2/lib/python3.9/site-packages/nequip/train/_loss.py", line 90, in __call__
    loss = loss / N
RuntimeError: The size of tensor a (461) must match the size of tensor b (5) at non-singleton dimension 0

I have tried many different batch sizes for training and validation, and all of them fail unless both batch sizes are set to 1. But batch_size=1 is very inefficient.

To Reproduce: This is my YAML file:

root: FeCHO_valid
run_name: FeCHO_run
workdir: FeCHO_workdir
seed: 42                                                  
# dataset_seed: 456    
append: true            
default_dtype: float32     
allow_tf32: false               
device:  cuda                   

# == network ==
model_builders:
 - SimpleIrrepsConfig   
 - EnergyModel               
 - PerSpeciesRescale       
 - ForceOutput           
 - RescaleEnergyEtc        

r_max: 6.0            
num_layers: 6    

l_max: 2              
parity: true            
num_features: 32     

nonlinearity_type: gate        
resnet: false                 

nonlinearity_scalars:
  e: silu
  o: tanh

nonlinearity_gates:
  e: silu
  o: tanh

# radial network basis
num_basis: 8                     
BesselBasis_trainable: true       
PolynomialCutoff_p: 6        

# radial network
invariant_layers: 3              
invariant_neurons: 64             
avg_num_neighbors: auto     
use_sc: true                 
compile_model: false         

# for extxyz file
dataset: ase
dataset_file_name: ./data/traindata.xyz
ase_args:
  format: extxyz
# include_keys:
#   - user_label
# key_mapping:
#   user_label: label0

# A list of chemical species found in the data. The NequIP atom types will be named after the chemical symbols and ordered by atomic number in ascending order.
# (In this case, NequIP's internal atom type 0 will be named H and type 1 will be named C.)
# Atoms in the input will be assigned NequIP atom types according to their atomic numbers.
chemical_symbols:
  - H
  - C
  - O
  - Fe

# If you want to use a different dataset for validation, you can specify
# the same types of options using a `validation_` prefix:
validation_dataset: ase
validation_dataset_file_name: ./data/test_1000.xyz                                            # need to be a format accepted by ase.io.read

# logging
wandb: true                        # we recommend using wandb for logging
wandb_project: FeCHO_valid                         # project name used in wandb
wandb_watch: false

# see https://docs.wandb.ai/ref/python/watch
# wandb_watch_kwargs:
#   log: all
#   log_freq: 1
#   log_graph: true

verbose: info         # the same as python logging, e.g. warning, info, debug, error. case insensitive
log_batch_freq: 200    # batch frequency, how often to print training errors within the same epoch
log_epoch_freq: 1     # epoch frequency, how often to print 
save_checkpoint_freq: -1    # frequency to save the intermediate checkpoint. no saving of intermediate checkpoints when the value is not positive.
save_ema_checkpoint_freq: -1  # frequency to save the intermediate ema checkpoint. no saving of intermediate checkpoints when the value is not positive.

# training
n_train: 12135                   # number of training data
n_val: 1000                      # number of validation data
learning_rate: 0.005        # learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
batch_size: 5        # batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
validation_batch_size: 5      # batch size for evaluating the model during validation. This does not affect the training results, but using the highest value possible (<=n_val) without running out of memory will speed up your training.
max_epochs: 100000   # stop training after _ number of epochs, we set a very large number here, it won't take this long in practice and we will use early stopping instead
train_val_split: random    # can be random or sequential. if sequential, first n_train elements are training, next n_val are val, else random, usually random is the right choice
shuffle: true         # If true, the data loader will shuffle the data, usually a good idea
metrics_key: validation_loss    # metrics used for scheduling and saving best model. Options: `set`_`quantity`, set can be either "train" or "validation", "quantity" can be loss or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse
use_ema: true          # if true, use exponential moving average on weights for val/test, usually helps a lot with training, in particular for energy errors
ema_decay: 0.99              # ema weight, typically set to 0.99 or 0.999
ema_use_num_updates: true     # whether to use number of updates when computing averages
# report_init_validation: true   # if True, report the validation error for just initialized model

# early stopping based on metrics values. 
# LR, wall and any keys printed in the log file can be used. 
# The key can start with Training or validation. If not defined, the validation value will be used.
early_stopping_patiences:    # stop early if a metric value stopped decreasing for n epochs
  validation_loss: 1000

# early_stopping_delta:    # If delta is defined, a decrease smaller than delta will not be considered as a decrease
#   validation_loss: 0.005

early_stopping_cumulative_delta: false   # If True, the minimum value recorded will not be updated when the decrease is smaller than delta

early_stopping_lower_bounds:     # stop early if a metric value is lower than the bound
  LR: 1.0e-5

# early_stopping_upper_bounds:    # stop early if a metric value is higher than the bound
#   cumulative_wall: 1.0e+100

# loss function
loss_coeffs:   # different weights to use in a weighted loss functions
  forces:   # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
    - 1
    - PerAtomMSELoss
  total_energy:                                                                    
    - 1

# output metrics
metrics_components:
  - - total_energy
    - rmse    
  - - total_energy
    - rmse
    - PerAtom: True                      # if true, energy is normalized by the number of atoms
  - - total_energy
    - mae    
  - - total_energy
    - mae
    - PerAtom: True                      # if true, energy is normalized by the number of atoms
  - - forces                               # key 
    - rmse                                  # "rmse" or "mae"
    - PerSpecies: True                     # if true, per species contribution is counted separately
      report_per_component: False          # if true, statistics on each component (i.e. fx, fy, fz) will be counted separately
  - - forces                                
    - mae                                  
    - PerSpecies: True                     
      report_per_component: False    

optimizer_name: Adam    
optimizer_amsgrad: true
optimizer_betas: !!python/tuple
  - 0.9
  - 0.999
optimizer_eps: 1.0e-08
optimizer_weight_decay: 0

max_gradient_norm: null

# lr scheduler, currently only supports the two options listed below, if you need more please file an issue
# first: on-plateau, reduce lr by a factor of lr_scheduler_factor if metrics_key hasn't improved for lr_scheduler_patience epochs
lr_scheduler_name: ReduceLROnPlateau
lr_scheduler_patience: 100
lr_scheduler_factor: 0.5

per_species_rescale_scales_trainable: false
# whether the scales are trainable. Defaults to False. Optional
per_species_rescale_shifts_trainable: false
# whether the shifts are trainable. Defaults to False. Optional

PerSpeciesScaleShift_shifts: [0.0, 0.0, 0.0, 0.0]

PerSpeciesScaleShift_scales: [1.0, 1.0, 1.0, 1.0]

Just run nequip-train <yamlfile> in a normal NequIP environment.

Expected behavior: The training process runs properly, as in the NequIP examples.

Environment (please complete the following information):

Additional context: The specific data are attached: nequip_1000vaild5.log, FeCHO_nequip.1.tar.gz, FeCHO_nequip.2.tar.gz

QuantumMisaka commented 1 year ago

I guess the error may be caused by my dataset, in which each structure frame has a different number of atoms. However, I cannot find a way to load multiple datasets in the YAML input, which is easy to do in a DeePMD-kit JSON-format input file.

QuantumMisaka commented 1 year ago

The problem lies in the format of the loss function. After I changed the loss function to the following:


# loss function
loss_coeffs:   # different weights to use in a weighted loss functions
  forces:   # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
    - 5
    - MSELoss
  total_energy:                                                                    
    - 1
    - MSELoss

batch_size can be set to any value, and no problem occurs.

I think the problem is in the PerAtomMSELoss function.
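A possible reading of the shapes in the RuntimeError, as a sketch: line 90 of _loss.py in the traceback divides the loss by N. Assuming N holds the atom count of each frame with shape (n_frames, 1), computed from the per-atom frame index the way torch.bincount would, the forces loss has one row per atom while N has one row per frame. The pure-Python snippet below only mimics the PyTorch broadcasting rule; the frame sizes are made up so that they add up to the 461 atoms and 5 frames reported in the error, and the helper names are hypothetical.

```python
from collections import Counter


def atoms_per_frame(batch):
    """Count atoms in each frame from a per-atom frame index (what torch.bincount does)."""
    counts = Counter(batch)
    return [counts.get(i, 0) for i in range(max(batch) + 1)]


def can_broadcast(shape_a, shape_b):
    """PyTorch/NumPy rule: align trailing dims; each pair must be equal or contain a 1."""
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        if a != b and 1 not in (a, b):
            return False
    return True


# Hypothetical batch of 5 frames with different atom counts, 461 atoms in total.
frame_sizes = [92, 92, 92, 92, 93]
batch = [frame for frame, n in enumerate(frame_sizes) for _ in range(n)]

n_shape = (len(atoms_per_frame(batch)), 1)  # per-frame atom counts: (5, 1)
forces_shape = (len(batch), 3)              # one row per atom: (461, 3)
energy_shape = (len(frame_sizes), 1)        # assumed one row per frame: (5, 1)

print(can_broadcast(forces_shape, n_shape))  # False: matches the RuntimeError
print(can_broadcast(energy_shape, n_shape))  # True: total_energy divides cleanly

# With batch_size 1 there is a single frame, so N has shape (1, 1) and the
# division broadcasts silently: no error, but the forces loss gets divided by N.
print(can_broadcast((92, 3), (1, 1)))        # True
```

Under these assumptions, this would also explain why batch_size 1 appeared to work: a (1, 1) tensor broadcasts against any shape, so the mismatch is hidden rather than fixed.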

Hongyu-yu commented 1 year ago

You should use

loss_coeffs:   # different weights to use in a weighted loss functions
  forces:   1.0
  total_energy:                                                                    
    - 1.0
    - PerAtomMSELoss

instead of

loss_coeffs:   # different weights to use in a weighted loss functions
  forces:   # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
    - 1
    - PerAtomMSELoss
  total_energy:                                                                    
    - 1

PerAtomMSELoss is meant for total_energy, not for forces.
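To illustrate why per-atom normalization belongs on total_energy, here is a minimal sketch of a per-atom energy MSE. It assumes the loss is the mean over frames of ((E_pred - E_ref) / N)^2, with N the atom count of each frame; this is not NequIP's code, its exact normalization may differ, and the function name per_atom_energy_mse is hypothetical.

```python
def per_atom_energy_mse(e_pred, e_ref, n_atoms):
    """Mean over frames of ((E_pred - E_ref) / N)^2, with N the atom count per frame."""
    if not (len(e_pred) == len(e_ref) == len(n_atoms)):
        raise ValueError("expected one predicted energy, reference energy and atom count per frame")
    errors = [((p - r) / n) ** 2 for p, r, n in zip(e_pred, e_ref, n_atoms)]
    return sum(errors) / len(errors)


# Two hypothetical frames with 10 and 20 atoms.
loss = per_atom_energy_mse([-100.0, -210.0], [-101.0, -200.0], [10, 20])
print(round(loss, 6))  # 0.13

# Setting N = 1 for every frame gives the plain (unnormalized) energy MSE.
plain = per_atom_energy_mse([-100.0, -210.0], [-101.0, -200.0], [1, 1])
print(plain)  # 50.5
```

Dividing by N puts frames with different atom counts, as in the Fe-C-H-O dataset above, on a comparable scale instead of letting the largest frames dominate. Forces already have one value per atom, so there is no per-frame quantity to normalize.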

Linux-cpp-lisp commented 1 year ago

@Hongyu-yu is entirely right here; thanks for responding to this!

I've added a better error message for this case.
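The actual fix lives in NequIP; the sketch below only illustrates the kind of config check such an error message could come from. check_loss_coeffs, normalize_loss_coeffs, and PER_ATOM_KEYS are hypothetical names, and treating a bare weight (or a one-element list) as MSELoss is an assumption based on the configs in this thread.

```python
# Hypothetical set of fields that have one value (row) per atom.
PER_ATOM_KEYS = {"forces"}


def normalize_loss_coeffs(loss_coeffs):
    """Map each loss_coeffs entry to (weight, loss_name).

    Assumption: a bare number or a one-element list means MSELoss.
    """
    normalized = {}
    for key, value in loss_coeffs.items():
        if isinstance(value, (int, float)):
            weight, name = value, "MSELoss"
        elif len(value) == 1:
            weight, name = value[0], "MSELoss"
        else:
            weight, name = value[0], value[1]
        normalized[key] = (float(weight), name)
    return normalized


def check_loss_coeffs(loss_coeffs):
    """Reject per-atom-normalized losses on fields that are already per atom."""
    normalized = normalize_loss_coeffs(loss_coeffs)
    for key, (_, name) in normalized.items():
        if name.startswith("PerAtom") and key in PER_ATOM_KEYS:
            raise ValueError(
                f"{name} divides per-frame values by the number of atoms; "
                f"use it for total_energy, not for the per-atom field '{key}'."
            )
    return normalized


# The two loss_coeffs blocks from this thread, as dicts after YAML parsing.
bad = {"forces": [1, "PerAtomMSELoss"], "total_energy": [1]}
good = {"forces": 1.0, "total_energy": [1.0, "PerAtomMSELoss"]}

print(check_loss_coeffs(good))
# {'forces': (1.0, 'MSELoss'), 'total_energy': (1.0, 'PerAtomMSELoss')}
try:
    check_loss_coeffs(bad)
except ValueError as err:
    print("config error:", err)
```

Failing at config time like this surfaces the mistake before training starts, rather than as a shape mismatch deep inside the loss (or, with batch_size 1, not at all).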