WillHua127 / EnzymeFlow

Official repository of EnzymeFlow
https://arxiv.org/abs/2410.00327

(READ IF ISSUES ARISE) int32/int64 inconsistency in all_atom.py one-hot #3

Open carbondrop-nick opened 1 month ago

carbondrop-nick commented 1 month ago

`\flowmatch\data\all_atom.py`, line 18: `GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group)` appears to create an int32 tensor, which results in the error `RuntimeError: one_hot is only applicable to index tensor.` when `group_mask` is computed. I converted it to int64 explicitly at line 19, and that seemed to fix it: `GROUP_IDX = GROUP_IDX.to(torch.int64)`. There is probably a better fix...
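A minimal sketch of the failure and of casting at construction time (one possible "better fix"). It assumes `residue_constants.restype_atom14_to_rigid_group` is a NumPy integer array. On Windows with NumPy releases before 2.0 the default integer type is int32, `torch.tensor` keeps the array's dtype, and `torch.nn.functional.one_hot` expects int64 indices. The table values below are made up, and `num_classes=8` assumes the usual eight rigid groups.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical stand-in for residue_constants.restype_atom14_to_rigid_group.
# np.int32 mimics NumPy's default integer on Windows before NumPy 2.0.
rigid_group_table = np.array([[0, 3, 3, 4], [0, 3, 5, 6]], dtype=np.int32)

# torch.tensor keeps the array's dtype, so this is int32, not int64.
group_idx = torch.tensor(rigid_group_table)
print(group_idx.dtype)  # torch.int32

try:
    F.one_hot(group_idx, num_classes=8)
except RuntimeError as err:
    # On the PyTorch versions reported in this issue, int32 indices are rejected here.
    print("one_hot rejected int32 indices:", err)

# Cast once, where the table becomes a tensor, so every later use is int64.
GROUP_IDX = torch.tensor(rigid_group_table, dtype=torch.int64)
group_mask = F.one_hot(GROUP_IDX, num_classes=8)
print(group_mask.shape)  # torch.Size([2, 4, 8])
```

`GROUP_IDX.to(torch.int64)` as in the comment above is equivalent; passing `dtype` at construction just avoids creating the intermediate int32 tensor.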

carbondrop-nick commented 1 month ago

Same issue in `utils/loss.py`: `residue_index` defaults to int32 for me, which prevents the later conversion to one-hot. I fixed it by making the dtype explicit.
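The same idea applies to `residue_index`: casting it once at the top of a function keeps every tensor derived from it in int64. The sketch below uses a hypothetical `relative_position_one_hot` helper (not the actual `loss.py` code) that one-hot encodes clipped pairwise residue offsets; `max_offset` is likewise an illustrative parameter.

```python
import torch
import torch.nn.functional as F


def relative_position_one_hot(residue_index: torch.Tensor, max_offset: int = 4) -> torch.Tensor:
    """Hypothetical helper: one-hot encode clipped pairwise residue offsets."""
    # Cast up front; offsets, clamp and shift below then all stay int64.
    residue_index = residue_index.to(torch.int64)
    offsets = residue_index[:, None] - residue_index[None, :]
    # Shift clipped offsets from [-max_offset, max_offset] into [0, 2 * max_offset].
    classes = offsets.clamp(-max_offset, max_offset) + max_offset
    return F.one_hot(classes, num_classes=2 * max_offset + 1)


# Works even when the caller hands in an int32 residue_index.
pair_features = relative_position_one_hot(torch.arange(5, dtype=torch.int32))
print(pair_features.shape)  # torch.Size([5, 5, 9])
```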

carbondrop-nick commented 1 month ago

With this int32/int64 issue fixed, I was able to run through the entire notebook (.ipynb) without errors.

JSATacoTruck commented 1 month ago

Writing to confirm that I hit the same issue in `all_atom.py` and `/ofold/utils/loss.py`. I resolved it by passing an explicit `dtype=torch.int64` to all three `residue_index.new_tensor` calls in `loss.py`. After that, the demo ran without errors.
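A quick check of why `new_tensor` needs the explicit `dtype`: per the PyTorch documentation, `Tensor.new_tensor` returns a tensor with the same dtype and device as the tensor it is called on, so calls made on an int32 `residue_index` produce int32 results. The index values below are placeholders, not the ones used in `loss.py`.

```python
import torch
import torch.nn.functional as F

# Placeholder int32 residue_index, matching what the comments above observed.
residue_index = torch.arange(4, dtype=torch.int32)

# Without dtype, new_tensor inherits int32 from residue_index.
implicit = residue_index.new_tensor([0, 1])
print(implicit.dtype)  # torch.int32

# With an explicit dtype, the result is int64 and safe to pass to one_hot.
explicit = residue_index.new_tensor([0, 1], dtype=torch.int64)
print(explicit.dtype)  # torch.int64

mask = F.one_hot(explicit, num_classes=2)
print(mask.tolist())  # [[1, 0], [0, 1]]
```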