MDIL-SNU / SevenNet

SevenNet - a graph neural network interatomic potential package supporting efficient multi-GPU parallel molecular dynamics simulations.
https://pubs.acs.org/doi/10.1021/acs.jctc.4c00190
GNU General Public License v3.0

mismatch in tensor size #102

Open thangckt opened 3 days ago

thangckt commented 3 days ago

Dear Developers,

I get the error below when I set is_train_stress: True

path/python3.11/site-packages/torch/nn/modules/loss.py:535: UserWarning: Using a target size (torch.Size([2, 6])) that is different to the input size (torch.Size([12])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Traceback (most recent call last):
  File "/home1/p001cao/app/miniconda3/envs/py11mace/bin/sevenn", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "path/python3.11/site-packages/sevenn/main/sevenn.py", line 105, in main
    train(global_config, working_dir)
  File "path/python3.11/site-packages/sevenn/scripts/train.py", line 85, in train
    processing_epoch(
  File "path/python3.11/site-packages/sevenn/scripts/processing_epoch.py", line 50, in processing_epoch
    trainer.run_one_epoch(
  File "path/python3.11/site-packages/sevenn/train/trainer.py", line 65, in run_one_epoch
    error_recorder.update(output)
  File "path/python3.11/site-packages/sevenn/error_recorder.py", line 271, in update
    self._update(output)
  File "path/python3.11/site-packages/sevenn/error_recorder.py", line 266, in _update
    metric.update(output)
  File "path/python3.11/site-packages/sevenn/error_recorder.py", line 150, in update
    se = self._square_error(y_ref, y_pred, self.vdim)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/sevenn/error_recorder.py", line 146, in _square_error
    return self._se(y_ref, y_pred).view(-1, vdim).sum(dim=1)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/torch/nn/modules/loss.py", line 535, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/torch/nn/functional.py", line 3365, in mse_loss
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/python3.11/site-packages/torch/functional.py", line 76, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (12) must match the size of tensor b (6) at non-singleton dimension 1

This error disappears when I set is_train_stress: False.

Could you help me a little? Thanks.

YutackPark commented 3 days ago

It is likely a consequence of sevenn.train.dataload::atoms_to_graph. That routine loads the stress from the atoms object but does not check whether the shape and type of the loaded stress are correct.
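A quick NumPy sketch of how the reported shapes can arise. It is an illustration, not SevenNet code, and it assumes that per-frame stresses read as flat (6,) arrays get concatenated along the first axis when a batch of two frames is built. That assumption is consistent with the input size [12] and target size [2, 6] in the warning above.

```python
import numpy as np

# Illustration only: two frames whose reference stress was read as flat (6,)
# arrays instead of (1, 6) rows.
flat_refs = [np.arange(6.0), np.arange(6.0) + 10.0]

# Concatenating flat per-frame arrays yields a single (12,) vector ...
y_ref_bad = np.concatenate(flat_refs)
# ... while per-frame predictions are stacked as (n_frames, 6).
y_pred = np.zeros((2, 6))

try:
    np.broadcast_shapes(y_ref_bad.shape, y_pred.shape)
    broadcast_ok = True
except ValueError:
    broadcast_ok = False  # trailing dims 12 and 6 cannot be broadcast

# Reshaping each frame to (1, 6) before concatenation gives (2, 6), which matches.
y_ref_good = np.concatenate([s.reshape(1, 6) for s in flat_refs], axis=0)
```

np.broadcast_shapes requires NumPy 1.20 or newer; PyTorch applies the same broadcasting rules, which is why F.mse_loss fails at dimension 1.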

Could you share the version and minimal data to reproduce the error?

thangckt commented 2 days ago

Hi @YutackPark, I attached a few frames of extxyz data. Please use these settings:

data_format_args:
    energy_key: 'ref_energy'
    force_key: 'ref_forces'
    stress_key: 'ref_stress'

data.txt

YutackPark commented 2 days ago

atoms.info['y_stress'] = atoms.info[stress_key]

When stress data is loaded via a custom key (in your case, ref_stress), the current code does not check whether it has shape (1, 6). 7net needs the stress to have shape (1, 6). This is a bug and will be patched.

For now, you can preprocess your ref_stress to have shape (1, 6) to avoid the problem. Sorry for the inconvenience. To prevent this kind of bug, I'm writing pytest tests following best practices, but they are not merged yet.
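A minimal sketch of the kind of shape check described above. coerce_stress is a hypothetical helper, not part of SevenNet; it accepts only a flat (6,) array or a (1, 6) row and rejects anything else (for example a full 3x3 tensor) instead of letting it broadcast silently.

```python
import numpy as np


def coerce_stress(value) -> np.ndarray:
    """Return the stress as a float64 array of shape (1, 6), or raise ValueError.

    Hypothetical helper: accepts a flat (6,) array (as ase.io.read returned in
    this thread) or an existing (1, 6) row; every other shape is rejected.
    """
    # np.asarray raises ValueError if the values cannot be converted to float.
    arr = np.asarray(value, dtype=np.float64)
    if arr.shape not in ((6,), (1, 6)):
        raise ValueError(f"expected stress of shape (6,) or (1, 6), got {arr.shape}")
    return arr.reshape(1, 6)
```

Rejecting a (2, 3) array, rather than reshaping any six numbers, keeps a malformed input from being accepted by accident.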

thangckt commented 1 day ago

Hi @YutackPark, thank you for the explanation.

you may preprocess your ref_stress to have (1, 6) shape to avoid the problem

ref_stress already has shape (1, 6). What preprocessing did you mean?

YutackPark commented 1 day ago

@thangckt, I copy-pasted the data you gave me, read it using ase.io.read, and got a (6,)-shaped array. Maybe ASE automatically flattens it to a plain (6,) shape before writing the file. It seems the only available workaround is to store the results using SinglePointCalculator, as I mentioned in the previous issue #61.

Also, on closer inspection, the feature seems to be broken because of the difference between the stress notation inside SevenNet, -1 * (xx, yy, zz, xy, yz, zx), and that of ASE (Voigt: xx, yy, zz, yz, zx, xy). I recommend not using it until the patch, unless you're very confident about it.
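A sketch of converting between the two notations described above. The component orders and the -1 factor are taken from the maintainer's comment, not from SevenNet's source, and the function names are hypothetical.

```python
import numpy as np

# Component orders as stated in this thread (assumption, not SevenNet source):
#   ASE Voigt:           ( xx, yy, zz, yz, zx, xy)
#   SevenNet internal: -1 * (xx, yy, zz, xy, yz, zx)
VOIGT_TO_SEVENNET = [0, 1, 2, 5, 3, 4]  # take xy, yz, zx from Voigt positions
SEVENNET_TO_VOIGT = [0, 1, 2, 4, 5, 3]  # take yz, zx, xy from SevenNet positions


def voigt_to_sevennet(voigt):
    """ASE Voigt stress -> SevenNet internal order, including the -1 factor."""
    v = np.asarray(voigt, dtype=np.float64)
    return -v[..., VOIGT_TO_SEVENNET]  # works for (6,) and (n, 6) inputs


def sevennet_to_voigt(s7):
    """Inverse of voigt_to_sevennet."""
    s = np.asarray(s7, dtype=np.float64)
    return -s[..., SEVENNET_TO_VOIGT]
```

The two index lists are inverse permutations, so a round trip returns the original Voigt vector.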

thangckt commented 1 day ago

Hi @YutackPark, thank you for the guidance.

The stress component order will be seriously misleading. May I ask why you don't follow the well-known Voigt notation?

YutackPark commented 1 day ago

Hi @thangckt, it is another side effect of following our group's previous MLIP package, SIMPLE-NN, or VASP itself. VASP uses the xx, yy, zz, xy, yz, zx order in its OUTCAR file, and so does SIMPLE-NN. I'm trying my best to hide this cumbersomeness from users, but it failed in this case. I may refactor the code to ALWAYS follow Voigt notation once it stabilizes.

thangckt commented 1 day ago

Hi @YutackPark, thank you so much for the information. I will follow your updates.

About the stress notation, I think it would be better to follow Voigt notation. Output from any specific software should be converted to this convention. Otherwise, it will be seriously misleading for users.
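One way to apply that suggestion is to name the component order explicitly and convert through the full 3x3 tensor, so no ordering is left implicit. This is a sketch under stated assumptions: the orders are the ones quoted in this thread, the helper names are hypothetical, and sign conventions (such as the -1 factor mentioned for SevenNet's internal notation) are deliberately not handled.

```python
import numpy as np

# (row, col) position of each labeled component in a 3x3 tensor.
_INDEX = {"xx": (0, 0), "yy": (1, 1), "zz": (2, 2),
          "yz": (1, 2), "zx": (2, 0), "xy": (0, 1)}
VOIGT_ORDER = ("xx", "yy", "zz", "yz", "zx", "xy")
VASP_ORDER = ("xx", "yy", "zz", "xy", "yz", "zx")  # order quoted for VASP OUTCAR / SIMPLE-NN


def to_tensor(components, order):
    """Build a symmetric 3x3 tensor from six components given in `order`."""
    if len(components) != 6 or sorted(order) != sorted(_INDEX):
        raise ValueError("need six components and each of xx, yy, zz, yz, zx, xy once")
    t = np.zeros((3, 3))
    for label, value in zip(order, components):
        i, j = _INDEX[label]
        t[i, j] = t[j, i] = value  # fill both halves of the symmetric tensor
    return t


def to_voigt(components, order):
    """Re-express six components given in `order` in Voigt order (no sign change)."""
    t = to_tensor(components, order)
    return np.array([t[_INDEX[label]] for label in VOIGT_ORDER])
```

Going through the tensor costs a few extra lines over a fixed index permutation, but any new source order only needs its label tuple.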

YutackPark commented 1 day ago

It was supposed to follow Voigt notation for this recently introduced EFS key feature; what happened here is simply my fault, and I agree with you. Anyway, thanks for the bug report. I'll notify you when I close the issue after the fix. The mixed notation inside the code is really confusing (even for me).