Open · Hongyu-yu opened 1 month ago
Nice work. A couple of questions and comments.
Do you have MACE results that you can share? The improvement in validation error for Allegro is impressive, but Allegro is generally less accurate than MACE, so I'd like to see how the improvements translate.
You should really have the y axis of your bar charts start at zero. This kind of thing is in the first chapter of "How to Lie with Statistics": https://g.co/kgs/RGfduso
@gabor1 Thanks for the quick feedback! On results: we recently entered an MLIP competition and tried MACE and MACE-KAN on a perovskite dataset (t1). MACE reaches 1.0 meV/atom and 7.5 meV/Å, while MACE-KAN reaches 0.9 meV/atom and 5.2 meV/Å, a clear improvement in forces (the same holds for the other datasets in the competition). The official benchmarks are still running, and we will make the details public as soon as possible. Thanks also for pointing out the statistical issues and the misleading plot: we will add runs with different random seeds and report results with plots whose y axis starts at zero. Thanks again for your valuable feedback!
Dear @Hongyu-yu, thank you very much for your PR and nice work. Could you please make the PR as small as possible and break out any changes that are not related to the KAN readout into another PR? Ideally only blocks.py, run_train.py, the argparser, and model.py would change, plus an additional test.
Just an extra note: I am very skeptical about KANs in general, and I am a bit surprised that they lead to any real improvement. So I would wait to see more results on MACE before merging this.
@ilyes319 Thanks for the quick feedback! The changes in this PR are all related to KAN; some of them just switch torch.save/load to use dill, which is needed to deploy KAN. The core changes are in mace/modules/models.py and mace/modules/blocks.py. I will add a test tomorrow. As for the improvement, I agree that systematic experiments on standard benchmarks are needed, and we are working on them. So far our results only show that KAN performs better in the competition mentioned above; we have not tested on benchmarks yet. We open-sourced the code now to meet the competition requirements, and we will report benchmark results here as soon as possible; then we can reach a firm conclusion about whether MultiKAN helps. From the viewpoint of the spherical basis, the final output layer is a mixture of basis functions mapped to a scalar energy. MultiKAN (KAN 2.0) may provide a more expressive combination than an MLP of the basis produced by the MACE interaction blocks, which could explain why KAN decodes the latent basis features more accurately. I'd welcome your opinion on this!
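To make the intuition concrete, here is a toy sketch of a KAN-style readout. This is *not* the pykan or MACE-KAN implementation: the function `kan_layer`, the choice of Gaussian radial basis functions, and all shapes are illustrative assumptions. The point is only that, unlike an MLP weight matrix, every input-output edge carries its own learnable univariate function:

```python
import math

def kan_layer(x, coeffs, centers, width=0.5):
    """Toy KAN-style layer (illustration only, not the pykan code).

    Instead of a scalar weight, each input-output edge (i, j) applies a
    learnable univariate function to x[i], modelled here as a sum of
    Gaussian radial basis functions with per-edge coefficients.

    x:       list of d_in latent features (one sample)
    coeffs:  coeffs[i][j][k] -- coefficient of basis k on edge (i, j)
    centers: list of n_basis fixed basis-function centers
    """
    d_in, d_out = len(coeffs), len(coeffs[0])
    out = [0.0] * d_out
    for i in range(d_in):
        # evaluate the shared basis functions at the i-th input feature
        phi = [math.exp(-((x[i] - c) / width) ** 2) for c in centers]
        for j in range(d_out):
            out[j] += sum(c * p for c, p in zip(coeffs[i][j], phi))
    return out

# 8 latent features -> 1 scalar "energy", 6 basis functions per edge
centers = [-2 + 4 * k / 5 for k in range(6)]
coeffs = [[[0.1] * 6] for _ in range(8)]   # toy coefficients
x = [0.0] * 8
energy = kan_layer(x, coeffs, centers)     # list with one scalar
```

In a trained model the `coeffs` would be optimized, so the readout can learn a different nonlinear shape on every edge rather than a single shared activation.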
Thanks, how are the changes to torch.save related to KAN?
torch.save does not work for KAN out of the box; it fails with this error:
```
INFO: Saving model to checkpoints/base_run-123.model
Traceback (most recent call last):
  File "/public/home/yuhongyu/anaconda3/envs/ace/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/mace/cli/run_train.py", line 734, in run
    torch.save(model, model_path)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/public/home/yuhongyu/anaconda3/envs/ace/lib/python3.9/site-packages/torch/serialization.py", line 589, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object 'Symbolic_KANLayer.__init__.<locals>.<listcomp>.<listcomp>.<lambda>'
```
It works with `torch.save(model, model_path, pickle_module=dill)`.
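The root cause is visible in the last line of the traceback: the stdlib pickle (which torch.save uses by default) cannot serialize lambdas defined inside a function. A minimal reproduction, using a toy stand-in class (`ToySymbolicLayer` is hypothetical, named only to mirror the failing `Symbolic_KANLayer`):

```python
import pickle

class ToySymbolicLayer:
    """Toy stand-in for pykan's Symbolic_KANLayer: like the real class,
    it stores lambdas created inside __init__."""
    def __init__(self, n):
        self.funs = [lambda x: x for _ in range(n)]  # local lambdas

failed = False
try:
    pickle.dumps(ToySymbolicLayer(3))
except (AttributeError, pickle.PicklingError):
    # e.g. "Can't pickle local object 'ToySymbolicLayer.__init__.<locals>.<lambda>'"
    failed = True
print(failed)  # True
```

dill serializes functions by value rather than by qualified-name lookup, so it can handle these local lambdas; that is why passing `pickle_module=dill` to torch.save succeeds.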
MACE + KAN
With an additional KAN readout for MACE, a more complex combination of the spherical basis emerges, yielding a more accurate and arguably more interpretable MACE model. Tests and benchmark results will be updated at https://arxiv.org/abs/2409.03430v1.

- `dill` is used for torch.save/load
- `pykan` is used for MultiKAN

Usage: add `--KAN_readout` on the command line, e.g.

```
mace_run_train --KAN_readout ...
```

We hope this pull request provides a more accurate MACE model for the community! If it helps and is used, please consider citing https://arxiv.org/abs/2409.03430v1 and http://arxiv.org/abs/2408.10205.

Hongyu Yu