SparkyTruck / deepmd-jax

A lightweight DeepPotentialMD with JAX backend, and more than that! Built for both performance and flexibility in pure Python.
MIT License
17 stars, 5 forks

Model evaluation of atomic labels #1

Open LuckyBoyYao opened 7 months ago

LuckyBoyYao commented 7 months ago

I have trained a model with atomic labels. When I use evaluate.py, the function get_batch doesn't return data with the shape of my original data. How can I solve this, please?

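[Attached image: problem.png, https://github.com/SparkyTruck/deepmd-jax/assets/113525990/7c3d2ad9-545e-4106-85ce-d153e2624d56]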
SparkyTruck commented 7 months ago

Hi LuckyBoyYao,

Can you kindly check whether the training and evaluation data have the same atom type index map? If this doesn't solve the problem and you'd like to share your model and data with me, I might be able to help as well.

Best, Ruiqi
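
The type-map check suggested above could be sketched as follows. This is a minimal sketch, assuming both datasets use the DeePMD-style layout in which each system directory holds a `type_map.raw` file listing element names in type-index order. The helper names `read_type_map` and `same_type_map` are hypothetical and not part of deepmd-jax.

```python
from pathlib import Path


def read_type_map(system_dir):
    """Return the element names listed in <system_dir>/type_map.raw, in order."""
    # Assumption: DeePMD-style layout with whitespace-separated element names.
    text = Path(system_dir, "type_map.raw").read_text()
    return text.split()


def same_type_map(train_dir, eval_dir):
    """True if both systems map type indices to the same elements in the same order."""
    return read_type_map(train_dir) == read_type_map(eval_dir)
```

If the two lists differ only in order (for example `["O", "H"]` versus `["H", "O"]`), the type indices in the two datasets refer to different elements, which is the mismatch the reply asks about.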

LuckyBoyYao commented 7 months ago

Yeah, that's the model evaluation I did with the original training set. My model and data are provided below.

data.zip (https://github.com/SparkyTruck/deepmd-jax/files/15073877/data.zip)

SparkyTruck commented 7 months ago

Hi,

Please use data.DPDataset(data_paths, ['coord', 'box', 'atomic_dipole'], {'atomic_sel':[1]}) in eval.py. The dataset needs {'atomic_sel':[1]} to know that the dipoles belong to this atom type. Also, model.wc_predict() gives the absolute coordinates of wc. To evaluate the error, you can directly use model.apply() (same as model.call()), which gives the positions relative to the atoms.

Sorry for the confusion; I don't seem to have provided an example of wc evaluation. If you're interested, kindly stay tuned, as I'll be upgrading the package to a much simpler interface in the future.

Best, Ruiqi
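
The two points in this reply (labels exist only for the selected atom type, and the error is measured on positions relative to the atoms) can be illustrated without the library. This is a minimal sketch under stated assumptions: labels are stored only for atoms whose type index appears in `atomic_sel`, and a relative position is the absolute position minus the host atom's coordinate, ignoring periodic boundaries. The functions `select_atoms` and `relative_positions` and all the numbers are hypothetical, not deepmd-jax API.

```python
def select_atoms(types, atomic_sel):
    """Indices of atoms whose type index is listed in atomic_sel."""
    return [i for i, t in enumerate(types) if t in atomic_sel]


def relative_positions(coords, wc_abs, selected):
    """Offset of each predicted site from its host atom (no periodic wrapping)."""
    return [
        [w - c for w, c in zip(wc_abs[k], coords[i])]
        for k, i in enumerate(selected)
    ]


# Hypothetical single frame: four atoms, labels only on type-1 atoms.
types = [1, 0, 0, 1]
coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [2.0, 2.0, 2.0]]
wc_abs = [[0.5, 0.0, 0.0], [2.0, 2.0, 2.5]]  # one absolute site per selected atom

selected = select_atoms(types, atomic_sel=[1])
rel = relative_positions(coords, wc_abs, selected)
print(len(types), len(selected))  # the label covers fewer atoms than the frame has
```

Under these assumptions, the per-frame label has one entry per selected atom rather than one per atom, which would account for a batch whose shape differs from the full coordinate array.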

LuckyBoyYao commented 7 months ago

Dear Ruiqi,

Thank you very much for your assistance. Looking forward to the updated version. It would be easier to understand the atomic model if you could provide an example evaluation of wc.

Also, may I ask why there is such a big difference in the loss between the training and validation sets?

[Attached images: image, dipole_parity_t, dipole_parity_v]

Best, Yao

SparkyTruck commented 7 months ago

If your dataset is correct, this means you are overfitting and should train for fewer steps. Reducing the network width might help a bit as well; the defaults (32, 64, etc.) have been well tested. I ran the script you gave me and there does not seem to be an issue:

Iter 96000 L 0.00005 Lval 0.00005 Time 2.30s
Iter 97000 L 0.00005 Lval 0.00005 Time 2.31s
Iter 98000 L 0.00005 Lval 0.00005 Time 2.28s
Iter 99000 L 0.00005 Lval 0.00005 Time 2.29s
Iter 100000 L 0.00005 Lval 0.00005 Time 2.30s
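
One way to spot the overfitting described here is to compare the training loss (`L`) and validation loss (`Lval`) columns of the log. This is a rough sketch that assumes the log lines have the format quoted in this comment. The helper names and the ratio threshold of 2.0 are arbitrary illustrative choices, not something the package provides.

```python
import re

# Matches lines such as "Iter 96000 L 0.00005 Lval 0.00005 Time 2.30s".
LOG_LINE = re.compile(r"Iter\s+(\d+)\s+L\s+([\d.eE+-]+)\s+Lval\s+([\d.eE+-]+)")


def parse_log(text):
    """Extract (iteration, train_loss, val_loss) tuples from the training log."""
    return [(int(i), float(l), float(v)) for i, l, v in LOG_LINE.findall(text)]


def overfit_steps(records, ratio=2.0):
    """Iterations where the validation loss exceeds `ratio` times the training loss."""
    # Assumption: a large Lval / L ratio is used as a crude overfitting signal.
    return [it for it, l, v in records if v > ratio * l]


log = """Iter 96000 L 0.00005 Lval 0.00005 Time 2.30s
Iter 97000 L 0.00005 Lval 0.00005 Time 2.31s"""
records = parse_log(log)
flagged = overfit_steps(records)
```

For the log quoted in this comment the two losses are equal, so nothing is flagged. A run in which `Lval` drifts well above `L` would be flagged from the iteration where the gap opens, which gives a rough idea of where to stop training.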

LuckyBoyYao commented 7 months ago

Yeah, this dataset is regular, so it can be trained well, but when I use a perturbed or AIMD dataset it's hard to train well. When I increase the width of the network, I get overfitting, which is what confuses me. Is there something I should do to the dataset?

SparkyTruck commented 7 months ago
  1. Make sure your dataset is correct
  2. Increase training set size
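
A basic check for the first point could look like the sketch below. It assumes each frame's atomic label is a flat list of `n_selected * 3` floats, which is an assumption about the data layout rather than something stated in this thread. The function name `check_labels` is hypothetical.

```python
import math


def check_labels(frames, n_selected, dim=3):
    """Return (frame_index, problem) pairs for label frames that look malformed."""
    problems = []
    for idx, frame in enumerate(frames):
        if len(frame) != n_selected * dim:
            # The label does not have one dim-vector per selected atom.
            problems.append((idx, "wrong length"))
        elif not all(math.isfinite(x) for x in frame):
            # NaN or inf values usually point to a failed reference calculation.
            problems.append((idx, "non-finite value"))
    return problems
```

An empty result only rules out these two problems. It does not confirm that the labels are physically correct.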

LuckyBoyYao commented 7 months ago

I greatly appreciate your advice. I will double-check my dataset.