lsj2408 / Transformer-M

[ICLR 2023] One Transformer Can Understand Both 2D & 3D Molecular Data (official implementation)
https://arxiv.org/abs/2210.01765
MIT License

Question about inputting only 2D Data #15

Open AlexWei21 opened 1 year ago

AlexWei21 commented 1 year ago

Hi!

Thank you for introducing such an interesting model to us and sharing the code!

I'm trying to run the model on 2D structures only. Would you mind providing a script for training the model with only 2D structures (e.g., for PCQM4M-LSC-V2)?

I tried changing dataset_name and setting add_3D to false in the sample 3D training script from the README, but that doesn't work. Looking into the code, I found that in tasks/graph_prediction.py, the load_dataset function of class GraphPredictionTask calls BatchedDataDataset with dataset_version set to "2D" for PCQM4M-LSC-V2. Training then fails in criterions/graph_predictions.py, line 45: ori_pos = sample['net_input']['batched_data']['pos'], KeyError: 'pos'.
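For readers following along, the failure can be reproduced in isolation. This is a minimal sketch, not the repository's code: the nested dictionary only mirrors the key path shown in the traceback, the 'x' entry is a placeholder, and get_positions is a hypothetical helper. It shows that a batch built without 3D coordinates has no 'pos' entry, so direct indexing raises KeyError, while a guarded lookup returns None.

```python
# Hypothetical sketch: mimics the nested sample layout from the traceback.
# It is NOT the Transformer-M implementation.

def get_positions(sample):
    """Return the 3D coordinates if present, else None (hypothetical helper)."""
    batched_data = sample["net_input"]["batched_data"]
    return batched_data.get("pos")


# A 2D-only batch: 'x' is a placeholder key, and there is no 'pos' entry.
sample_2d = {"net_input": {"batched_data": {"x": [[6], [8]]}}}

error_key = None
try:
    # Same access pattern as the line quoted in the error message.
    ori_pos = sample_2d["net_input"]["batched_data"]["pos"]
except KeyError as err:
    error_key = err.args[0]
    print("KeyError:", repr(error_key))

# The guarded lookup does not raise; it reports that positions are absent.
print(get_positions(sample_2d))
```

Whether the criterion should skip the position-based term when 'pos' is missing is a design decision for the maintainers; the sketch only shows where the error comes from.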

Thank you so much!

sunyrain commented 1 week ago

Same Question

sunyrain commented 1 week ago

Have you found the solution?

sunyrain commented 1 week ago

I changed the parameter "mode_prob: ${mode_prob} for {2D+3D, 2D, 3D}" to 0,1,0, and it works for me.
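As an illustration of why 0,1,0 has this effect, here is a minimal sketch under stated assumptions. It assumes the three values are probabilities for the modes in the order given above (2D+3D, 2D, 3D). The functions parse_mode_prob and sample_mode are hypothetical and not taken from the repository; how Transformer-M actually consumes mode_prob may differ.

```python
import random

# Mode order as listed in the comment above: {2D+3D, 2D, 3D}.
MODES = ("2D+3D", "2D", "3D")


def parse_mode_prob(text):
    """Parse a comma-separated string such as "0,1,0" (hypothetical helper)."""
    probs = [float(p) for p in text.split(",")]
    if len(probs) != len(MODES):
        raise ValueError(f"expected {len(MODES)} values, got {len(probs)}")
    if any(p < 0 for p in probs) or abs(sum(probs) - 1.0) > 1e-6:
        raise ValueError("probabilities must be non-negative and sum to 1")
    return probs


def sample_mode(probs, rng):
    """Draw one input mode according to the given probabilities."""
    return rng.choices(MODES, weights=probs, k=1)[0]


rng = random.Random(0)
probs = parse_mode_prob("0,1,0")

# With all probability mass on the second entry, only "2D" is ever drawn.
drawn = {sample_mode(probs, rng) for _ in range(1000)}
print(drawn)
```

Under this reading, a batch is never routed through a 3D mode, which would be consistent with the 'pos' lookup no longer being a problem.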