Closed: PAIplus closed this issue 7 months ago.
The test MAE for shear modulus is 20.2761680843171 and the test MAE for bulk modulus is 59.41012252656236. I wonder whether this is caused by the line

`mean_absolute_error(np.array(targets), np.array(predictions)) * std_train`

which multiplies the final result by `std_train`.
Please check the data loading logic in the code base and let me know if any issues remain.
Bulk and shear moduli use a slightly different data loading process, which may have caused the problem you describe. I think you can solve it by looking into the code, but don't hesitate to contact me if it is not solved.
The `* std_train` factor is not the cause.
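A short sketch of why the `* std_train` factor is harmless, assuming (this is an assumption, not confirmed in the thread) that targets are standardized as `(y - mean_train) / std_train` before training. The mean shift cancels inside the absolute difference, so the MAE on the standardized scale times `std_train` equals the MAE on the original scale. The statistics and values below are made up for illustration only, and `mean_absolute_error` here is a plain reimplementation with the same definition as scikit-learn's.

```python
def mean_absolute_error(y_true, y_pred):
    """Plain MAE, same definition as sklearn.metrics.mean_absolute_error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def standardize(values, mean, std):
    """Map raw values to the standardized scale (assumed preprocessing)."""
    return [(v - mean) / std for v in values]


# Hypothetical training-set statistics and raw values (GPa), illustration only.
mean_train, std_train = 100.0, 50.0
raw_targets = [80.0, 120.0, 200.0]
raw_predictions = [90.0, 100.0, 210.0]

targets = standardize(raw_targets, mean_train, std_train)
predictions = standardize(raw_predictions, mean_train, std_train)

mae_scaled = mean_absolute_error(targets, predictions)  # MAE on standardized scale
mae_raw = mean_absolute_error(raw_targets, raw_predictions)  # MAE in GPa

# Both values are approximately 13.33: rescaling by std_train recovers GPa.
print(mae_scaled * std_train, mae_raw)
```

So a large MAE after rescaling points to the data itself (for example, which samples were loaded for bulk or shear), not to the rescaling step.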
Specifically, to rerun the bulk and shear experiments, change the `prop` argument in train_mp.py to the corresponding property from the list defined in train_mp.py, and add `mp_id_list="bulk"` or `mp_id_list="shear"` after `save_dataloader=False`.
For example, `train_prop_model(learning_rate=0.001, name="matformer", dataset="megnet", prop=props[2], pyg_input=True, n_epochs=500, batch_size=64, use_lattice=True, output_dir="./matformer_mp_bulk", use_angle=False, save_dataloader=False, mp_id_list="bulk")`
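One way to keep the property and `mp_id_list` paired is a small helper that builds the keyword arguments for `train_prop_model`. This is only a sketch: `make_mp_run_kwargs`, `VALID_MP_ID_LISTS` and the `"PROP_FOR_BULK"` placeholder are hypothetical names, not part of the repo; the defaults simply copy the example call above. Restricting `mp_id_list` to `"bulk"` and `"shear"` is a choice made for this sketch, covering only the two runs discussed here.

```python
# Hypothetical helper; only the bulk/shear runs from this thread are covered.
VALID_MP_ID_LISTS = {"bulk", "shear"}


def make_mp_run_kwargs(prop, mp_id_list, output_dir, **overrides):
    """Build keyword arguments for train_prop_model (sketch, not repo code).

    Defaults mirror the example call in this thread; anything passed in
    ``overrides`` (e.g. learning_rate, n_epochs) replaces them.
    """
    if mp_id_list not in VALID_MP_ID_LISTS:
        raise ValueError(
            f"mp_id_list must be one of {sorted(VALID_MP_ID_LISTS)}, got {mp_id_list!r}"
        )
    kwargs = dict(
        learning_rate=0.001,
        name="matformer",
        dataset="megnet",
        prop=prop,
        pyg_input=True,
        n_epochs=500,
        batch_size=64,
        use_lattice=True,
        output_dir=output_dir,
        use_angle=False,
        save_dataloader=False,
        mp_id_list=mp_id_list,  # plain ASCII quotes; curly quotes are a SyntaxError
    )
    kwargs.update(overrides)
    return kwargs


# "PROP_FOR_BULK" stands in for the matching entry of props in train_mp.py.
bulk_kwargs = make_mp_run_kwargs("PROP_FOR_BULK", "bulk", "./matformer_mp_bulk")
# train_prop_model(**bulk_kwargs)  # the actual call would happen in train_mp.py
```

Passing `n_epochs=...` or `learning_rate=...` as overrides lets you follow the settings in the paper without editing the defaults.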
You can adjust the model settings, learning rate, etc. as described in the paper. Please let me know if your problem is not solved.
Thank you for your interest in the Matformer code. I believe the issue you raised is valuable and informative for future users of this repo.
I will also update the readme file accordingly once solved.
Do you still have problems, @PAIplus? If not, I will close this issue as solved.
Thanks a lot for your help. I have figured out the problem.
I am trying to reproduce the results in your paper. Unfortunately, with the learning rate and number of epochs given in your paper, I cannot reproduce the results for bulk modulus and shear modulus. Could you please explain how to reproduce them? I followed your instructions: I downloaded the data, put it under the folder "data", and ran the code with the script provided in the repo.