ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Add KAN readout options for MACE with possible better accuracy #655

Open Hongyu-yu opened 1 month ago

Hongyu-yu commented 1 month ago

MACE + KAN

With an additional KAN readout for MACE, a more complex combination of the spherical basis emerges, which gives a more accurate MACE model and even makes MACE more explainable. Tests and benchmark results will be updated at https://arxiv.org/abs/2409.03430v1.

dill is used for torch.save/load.
pykan is used for multikan.

Usage: add --KAN_readout on the command line, e.g. mace_run_train --KAN_readout ...

I hope this pull request provides a more accurate MACE model to the community! If it helps and is used, please consider citing https://arxiv.org/abs/2409.03430v1 and http://arxiv.org/abs/2408.10205.

Hongyu Yu

gabor1 commented 1 month ago

Nice work. A couple of questions and comments.

Do you have mace results that you can share? The improvement in validation error for Allegro is impressive, but Allegro is generally less accurate than mace. So I'd like to see how the improvements translate.

You should really have the y axis of your bar charts start at zero. This kind of thing is in the first chapter of "How to Lie with Statistics" https://g.co/kgs/RGfduso

Hongyu-yu commented 1 month ago

@gabor1 Thanks for the quick feedback! As for results, I recently took part in an MLIP competition and tried MACE and MACE-KAN on perovskite (t1). MACE ends at 1.0 meV/atom and 7.5 meV/Å, while MACE-KAN ends at 0.9 meV/atom and 5.2 meV/Å, a clearly better result for forces (the same holds for the other datasets in the competition). More detailed runs and the official benchmarks are still in progress, and we will make them public as soon as possible. Thanks also for the reminder about statistical errors and the misleading plots! We will add tests with different random seeds and give more convincing results, with plots whose y axis starts at zero. Thanks again for your valuable feedback!
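
To put the competition numbers quoted above on a common scale, the short sketch below converts them into relative error reductions. It uses only the four values given in this comment; the helper name relative_reduction is hypothetical, and since the comment does not say which error metric was used, none is assumed here.

```python
def relative_reduction(baseline, improved):
    """Percentage by which `improved` is lower than `baseline`."""
    return 100.0 * (baseline - improved) / baseline


# Values quoted in the comment: MACE vs. MACE-KAN on perovskite (t1).
energy_gain = relative_reduction(1.0, 0.9)  # energy errors in meV/atom
force_gain = relative_reduction(7.5, 5.2)   # force errors in meV/Å

print(f"energy error reduction: {energy_gain:.1f}%")
print(f"force error reduction:  {force_gain:.1f}%")
```

This works out to roughly a 10% reduction in the energy error and roughly a 31% reduction in the force error, for this single dataset and without error bars.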

ilyes319 commented 1 month ago

Dear @Hongyu-yu, thank you very much for your PR and nice work. Could you please make the PR as small as possible, and move any changes that are not related to the KAN readout into a separate PR? Ideally, only blocks.py, run_train.py, the argparser and model.py would be changed, plus an additional test.

Just an extra note: I am very skeptical about KANs in general, and I am a bit surprised that they lead to any real improvement. So I would wait to see more results on MACE before merging this.

Hongyu-yu commented 1 month ago

@ilyes319 Thanks for the quick feedback! Actually, the changes in this PR are all related to KAN; some of them just switch torch.save/load to use dill, which is needed to deploy KAN. The core changes are in mace/modules/models.py and mace/modules/blocks.py. As for the additional test, I will add one tomorrow.

Regarding the improvement, I agree that a systematic check and experiments on benchmarks are needed, and we are working on them. So far, our results show that KAN gives better results in the competition mentioned above, but it has not been tested on benchmarks yet. We open-sourced the code now to meet the competition requirements. We will report benchmark results here as soon as we have them; maybe then we can reach a firm conclusion about whether MultiKAN works.

From the point of view of the spherical basis, the last output layer is a mixture of basis functions into a scalar energy. MultiKAN (KAN 2.0) may provide a more complex combination than an MLP of the basis produced by the MACE interaction part, which could result in better accuracy. This could be the reason why KAN works better at decoding the latent basis features. I would welcome your opinion!

ilyes319 commented 1 month ago

Thanks, how are the changes to torch.save related to KAN?

Hongyu-yu commented 1 month ago

Calling torch.save directly does not work for KAN; it fails with the following error:

INFO: Saving model to checkpoints/base_run-123.model
Traceback (most recent call last):
  File "/public/home/yuhongyu/anaconda3/envs/ace/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 734, in run
    torch.save(model, model_path)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object 'Symbolic_KANLayer.__init__.<locals>.<listcomp>.<listcomp>.<lambda>'

But it works with torch.save(model, model_path, pickle_module=dill).
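
The failure in the traceback can be reproduced without torch or pykan. The sketch below is a minimal illustration, assuming only that the layer stores lambdas created inside __init__, as the error message suggests. SymbolicLayerSketch and try_pickle are hypothetical names and are not part of pykan or MACE. By default torch.save serializes with the standard pickle module, which stores functions by reference and therefore cannot handle lambdas defined in a local scope; dill serializes them by value.

```python
import pickle


class SymbolicLayerSketch:
    """Hypothetical stand-in for a layer that keeps lambdas built in __init__."""

    def __init__(self, n_in=2, n_out=2):
        # Nested comprehension of lambdas, mirroring the pattern named in the traceback.
        self.funs = [[lambda x: x for _ in range(n_in)] for _ in range(n_out)]


def try_pickle(obj, pickle_module=pickle):
    """Return None if `obj` serializes, otherwise a short description of the error."""
    try:
        pickle_module.dumps(obj)
        return None
    except Exception as exc:  # the exact exception type varies across Python versions
        return f"{type(exc).__name__}: {exc}"


layer = SymbolicLayerSketch()

# The standard pickle module cannot look up a locally defined lambda by name.
error = try_pickle(layer)
print("stdlib pickle:", error)

# dill is optional here; it is only tried if it is installed.
try:
    import dill
except ImportError:
    dill = None

if dill is not None:
    print("dill:", try_pickle(layer, pickle_module=dill))
```

An alternative that avoids the extra dependency would be to replace the lambdas with module-level functions, but that would require changes on the pykan side rather than in MACE.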