FAIR-Chem / fairchem

FAIR Chemistry's library of machine learning methods for chemistry
https://opencatalystproject.org/

Divergent Training Behavior with EquiformerV2 on ANI-2x Dataset with PBC #899

Closed IliasChair closed 2 weeks ago

IliasChair commented 2 weeks ago

Hello community,

I'm using the EquiformerV2 model to train a neural network on LMDB splits I created from the ANI-2x dataset. Since the ANI-2x dataset only contains non-periodic structures, I implemented a workaround to mimic a periodic system by placing each molecule within a "dummy" 50 ų box.

I did this because training without PBC previously led to errors, although recent updates seem to have resolved these [*]. Now, however, I’m observing divergent behavior between two setups, one with and one without PBC.

Training Configurations:

[Image: W&B Chart 30.10.2024, 18:38:46] [Image: W&B Chart 30.10.2024, 18:38:17]

Observations:

Forces Magnitude Error: The error in force predictions drops more steadily with PBC, while it plateaus at a higher level without PBC.

Cosine Similarity: The PBC setup improves to around 0.8 in cosine similarity, whereas the non-PBC setup remains around 0.2, with little further improvement.
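For readers unfamiliar with the two metrics, here is a minimal sketch of how they are commonly defined. It assumes that cosine similarity is averaged per atom and that magnitude error is the mean absolute difference between force norms; the function names are illustrative and are not the fairchem evaluator API.

```python
import math

def cosine_similarity(pred, target):
    """Mean per-atom cosine of the angle between predicted and reference forces."""
    total = 0.0
    for p, t in zip(pred, target):
        dot = sum(a * b for a, b in zip(p, t))
        norm_p = math.sqrt(sum(a * a for a in p))
        norm_t = math.sqrt(sum(b * b for b in t))
        total += dot / max(norm_p * norm_t, 1e-8)  # guard against zero vectors
    return total / len(pred)

def magnitude_error(pred, target):
    """Mean per-atom absolute difference between force vector lengths."""
    total = 0.0
    for p, t in zip(pred, target):
        norm_p = math.sqrt(sum(a * a for a in p))
        norm_t = math.sqrt(sum(b * b for b in t))
        total += abs(norm_p - norm_t)
    return total / len(pred)

# Toy case: both predicted magnitudes are right, but one direction is off by 90 degrees.
target = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
pred = [(0.0, 1.0, 0.0), (0.0, 2.0, 0.0)]
print(cosine_similarity(pred, target), magnitude_error(pred, target))
```

Under these definitions, a run can have a small magnitude error and still a poor cosine similarity, since the two metrics measure length and direction separately.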

Hypothesis: One possible explanation is that, because I created the LMDB files using the a2g methods with this dummy periodic system, the dataset has been effectively "baked" with periodicity in mind. This might mean that it’s no longer possible to treat it as a truly aperiodic system, since the initial graph construction assumed periodicity.

Another factor could be the way PBC artificially expands the perceived system size: by appending copies of the original box to the sides, each box remains isolated from other molecules, which could make errors appear smaller relative to the total system size.
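One way to sanity-check this second hypothesis is a purely geometric bound. The sketch below assumes a cubic box and uses hypothetical helper names; it says nothing about how a2g or the model actually builds edges. For a cubic box of side L, any atom and any periodic image of another atom are at least L minus the molecule's largest extent apart, so images can only fall within the cutoff if that bound is smaller than max_radius.

```python
def min_image_separation(positions, box_length):
    """Lower bound on the distance between an atom and any periodic image
    of another atom: box length minus the largest extent of the molecule."""
    extents = [
        max(p[i] for p in positions) - min(p[i] for p in positions)
        for i in range(3)
    ]
    return box_length - max(extents)

def images_may_interact(positions, box_length, max_radius):
    """True if a periodic image could lie within the cutoff radius."""
    return min_image_separation(positions, box_length) < max_radius

# Hypothetical molecule spanning 10 Å along x, in a 50 Å box with a 12 Å cutoff.
molecule = [(-5.0, 0.0, 0.0), (5.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
print(min_image_separation(molecule, 50.0))       # 40.0
print(images_may_interact(molecule, 50.0, 12.0))  # False
```

If this check returns False for every structure, neighboring copies of the box should not contribute extra edges from geometry alone, and any additional edges would have to come from how the graph is constructed.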

Thank you for any suggestions or ideas on interpreting these results.

[*] Edit: The errors stemmed from the AtomsToGraphs class, not from the EquiformerV2 code itself. It is not possible to convert non-periodic ASE Atoms objects to graphs using that class.

rayg1234 commented 2 weeks ago

Hi @IliasChair, thanks for flagging this. Could you share your training configs? We'll try to debug this.

IliasChair commented 2 weeks ago

Hello Ray, thank you so much for looking into this! Although I’m using EqV2 with DeNS, I don’t believe this is the source of the issue. If I saw correctly, the eqv2_dens branch was merged into main just today, so this shouldn't cause any issues. That said, please let me know if you need any more details.

Unfortunately, I can’t share the dataset itself due to its size, but the effects should be reproducible with any dataset prepared as I described. If you’re unable to reproduce the issue, I can likely segment the database and send a portion.

See below for parameters:

trainer: equiformerv2_dens

dataset:
  train:
    format: lmdb
    src: /T1x_ANI-2x/LMDBS/train.lmdb
    key_mapping:
      y: energy
      force: forces
    transforms:
      normalizer:
        energy:
          mean: -444.2186136026055
          stdev: 252.48679232685024
        forces:
          mean: 0.0
          stdev: 0.051076667893874154
  val:
    src: /T1x_ANI-2x/LMDBS/validation.lmdb
  test:
    src: /T1x_ANI-2x/LMDBS/test_ts.lmdb

logger:
 name: wandb
 project: rundir_final

outputs:
  energy:
    property: energy
    shape: 1
    level: system
  forces:
    property: forces
    irrep_dim: 1
    level: atom
    train_on_free_atoms: True
    eval_on_free_atoms: True

loss_functions:
  - energy:
      fn: mae
      coefficient: 2
  - forces:
      fn: l2mae
      coefficient: 100

evaluation_metrics:
  metrics:
    energy:
      - mae
    forces:
      - mae
      - cosine_similarity
      - magnitude_error
    misc:
      - energy_forces_within_threshold
  primary_metric: forces_mae

hide_eval_progressbar: False

model:
  name: hydra
  backbone:
    model: equiformer_v2_dens_backbone

    use_pbc: True             # <-- set to False in the second case
    regress_forces: True
    otf_graph: True

    enforce_max_neighbors_strictly: True   # <-- set to False in the second case

    max_neighbors: 30
    max_radius:               12
    #max_num_elements:         90

    num_layers: 16
    sphere_channels: 128
    num_heads: 8
    attn_alpha_channels: 64
    attn_value_channels: 16
    attn_hidden_channels: 64
    ffn_hidden_channels: 128
    norm_type: 'layer_norm_sh'

    lmax_list: [6]
    mmax_list: [2]
    grid_resolution:          18 

    num_sphere_samples: 128

    edge_channels:              128
    use_atom_edge_embedding:    True
    share_atom_edge_embedding:  False
    use_m_share_rad:            False
    distance_function: "gaussian"
    num_distance_basis:        512         # not used

    attn_activation:          'silu'
    # turn use_s2_act_attn off!
    use_s2_act_attn:          False      
    use_attn_renorm:          True     
    ffn_activation:           'silu'      
    use_gate_act:             False      
    use_grid_mlp:             True       
    use_sep_s2_act:           True      

    alpha_drop:               0.1     
    drop_path_rate:           0.1      
    proj_drop:                0.0

    weight_init:              'uniform'

    use_force_encoding:                   True
    use_noise_schedule_sigma_encoding:    False
    #basis_width_scalar: 2.0

  # Denoising heads:
  heads:
    energy:
      module: equiformer_v2_dens_energy_head
    forces:
      module: equiformer_v2_dens_force_head
    # noise:
    #   module: dens_rank2_symmetric_head

optim:
  batch_size:                   90     
  eval_batch_size:              90      
  load_balancing: False        # had to turn this off because of issue  #753
  load_balancing_on_error:      warn_and_no_balance
  num_workers: 8
  lr_initial:                   0.0095   

  optimizer: AdamW
  optimizer_params:
    weight_decay: 0.001
  scheduler: LambdaLR
  scheduler_params:
    lambda_type: cosine
    warmup_factor: 0.2
    warmup_epochs: 0.001
    lr_min_factor: 0.01

  max_epochs: 3
  clip_grad_norm: 50
  ema_decay: 0.999

  eval_every: 14000

  # for denoising positions
  use_denoising_pos:            True
  denoising_pos_params:
    prob:                       0.25      
    fixed_noise_std:            True
    std:                        0.1
    num_steps:                  50
    std_low:                    0.02
    std_high:                   0.4
    corrupt_ratio:              0.25
  denoising_pos_coefficient:    15

rayg1234 commented 2 weeks ago

@IliasChair, we found several potential issues with this setup on molecules:

1) Setting use_pbc=True requires the atoms to fit inside the unit cell. If you have atoms that run over (for example, if you create a box and place the atoms at the corner of the cell), then enabling PBC will create a large number of duplicate edges. The model can still work in this setting, but we're not sure about its behavior in this case. We will introduce an assertion to make sure the atoms are always inside the cell.

2) With use_pbc=False, for ANI-2x you might need to increase max_neighbors to capture the entire system.

3) enforce_max_neighbors_strictly doesn't do anything when use_pbc=False; when use_pbc=True, it only ensures that the edges are selected deterministically.
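Point 2 can be checked offline with a brute-force neighbor count. This is a minimal sketch with hypothetical helper names; it is not the graph construction used by fairchem, and it assumes a neighbor is any other atom at a distance of at most max_radius.

```python
import math

def neighbor_counts(positions, max_radius):
    """Number of other atoms within max_radius of each atom (no PBC)."""
    counts = []
    for i, pi in enumerate(positions):
        n = 0
        for j, pj in enumerate(positions):
            if i != j and math.dist(pi, pj) <= max_radius:
                n += 1
        counts.append(n)
    return counts

def truncated_atoms(positions, max_radius, max_neighbors):
    """Indices of atoms that have more in-range neighbors than max_neighbors."""
    counts = neighbor_counts(positions, max_radius)
    return [i for i, n in enumerate(counts) if n > max_neighbors]

# Hypothetical 40-atom chain that fits entirely inside a 12 Å cutoff.
chain = [(0.2 * i, 0.0, 0.0) for i in range(40)]
print(len(truncated_atoms(chain, max_radius=12.0, max_neighbors=30)))  # 40
```

With max_radius: 12 and max_neighbors: 30 as in the config above, any molecule with more than 31 atoms whose diameter is under 12 Å would have neighbors dropped for every atom, so raising max_neighbors until this list is empty is one way to capture the whole system.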

You should also verify that the graph being generated and fed into the eqv2 network is the one you expect by checking it at the output of graph generation here: https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/models/equiformer_v2/equiformer_v2.py#L403

Let us know if this helps!

IliasChair commented 2 weeks ago

Hello, thanks for looking into this. Your first point seems to explain what we are seeing. ASE Atoms objects always start the box at (0, 0, 0), but since the atomic coordinates in the ANI-2x dataset are centered around this origin, setting use_pbc=False likely causes some atoms to get cut off, which could explain the high error. With use_pbc=True this still leads to issues because not all atoms are within the central cell, leading to duplicate edges, as you mentioned. I’ll rework the dataset to ensure all atoms are shifted into the unit cell and see if it helps.
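A minimal sketch of the planned rework, assuming a cubic box with its corner at the origin. The helper names are hypothetical; ASE's Atoms.center() should serve a similar purpose when working with Atoms objects directly. A rigid translation leaves all interatomic distances unchanged, so the energy and force labels stay valid.

```python
def center_in_box(positions, box_length):
    """Translate the molecule so its geometric center sits at the box center."""
    n = len(positions)
    centroid = [sum(p[i] for p in positions) / n for i in range(3)]
    shift = [box_length / 2 - c for c in centroid]
    return [tuple(p[i] + shift[i] for i in range(3)) for p in positions]

def all_inside_box(positions, box_length):
    """True if every coordinate lies in [0, box_length)."""
    return all(0.0 <= x < box_length for p in positions for x in p)

# Hypothetical diatomic centered on the origin, as in the raw dataset.
raw = [(-1.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
shifted = center_in_box(raw, 50.0)
print(all_inside_box(raw, 50.0), all_inside_box(shifted, 50.0))  # False True
```

Running all_inside_box over every structure before writing the LMDB files would catch any molecule that still extends past the cell.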

Regarding issue #900, since this is a persistent issue, would you prefer that I leave it open until there’s a permanent fix?

I will rework my dataset as mentioned and close this issue for now. I’ll also leave a comment with my findings in case it’s useful for others.

Best, Ilias

rayg1234 commented 2 weeks ago

> Regarding issue #900, since this is a persistent issue, would you prefer that I leave it open until there’s a permanent fix?

Feel free to leave this open for now. Thanks again, we really appreciate you reporting these issues!