Add memoization of dihedral_pairs in the datasets so that they are computed only once, during the first epoch, then kept in memory and reused in all later epochs. In my experiments, computing the dihedral pairs took up 73% of the total runtime, so this should give a substantial speedup: the overhead now occurs only in the first epoch, and the additional memory usage is negligible.
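The caching pattern can be sketched in plain Python as follows. This is an illustration under assumptions, not the code in this PR: graphs are represented as hypothetical edge lists, and compute_dihedral_pairs is a hypothetical stand-in for the expensive step (here it returns the endpoints of every simple path a-b-c-d, which may not match the repository's exact definition). A call counter shows that the expensive step runs only once per graph across epochs.

```python
from typing import Dict, List, Sequence, Set, Tuple

Pair = Tuple[int, int]
Graph = Sequence[Pair]  # hypothetical format: undirected edge list

CALLS = {"compute": 0}  # counts how often the expensive step actually runs


def compute_dihedral_pairs(edges: Graph) -> List[Pair]:
    """Hypothetical stand-in for the expensive computation: endpoints (a, d)
    of every simple path a-b-c-d in the graph, in both directions."""
    CALLS["compute"] += 1
    nbrs: Dict[int, Set[int]] = {}
    for u, v in edges:
        nbrs.setdefault(u, set()).add(v)
        nbrs.setdefault(v, set()).add(u)
    pairs = set()
    for b, c in edges:
        for x, y in ((b, c), (c, b)):  # walk the central bond both ways
            for a in nbrs[x] - {y}:
                for d in nbrs[y] - {x, a}:
                    pairs.add((a, d))
    return sorted(pairs)


class MemoizedDihedralDataset:
    """Computes dihedral pairs on the first access to each graph (i.e. in the
    first epoch) and serves them from an in-memory cache afterwards."""

    def __init__(self, graphs: List[Graph]):
        self.graphs = graphs
        self._cache: Dict[int, List[Pair]] = {}

    def __len__(self) -> int:
        return len(self.graphs)

    def __getitem__(self, idx: int):
        if idx not in self._cache:  # first access: compute and store
            self._cache[idx] = compute_dihedral_pairs(self.graphs[idx])
        return self.graphs[idx], self._cache[idx]


ds = MemoizedDihedralDataset([[(0, 1), (1, 2), (2, 3)]])
for epoch in range(3):
    for i in range(len(ds)):
        graph, pairs = ds[i]
print(pairs)  # [(0, 3), (3, 0)]
print(CALLS)  # {'compute': 1}
```

One caveat for this kind of in-dataset cache: with a multi-process PyTorch DataLoader (num_workers > 0), each worker holds its own copy of the dataset, so entries cached inside a worker are not shared with the main process. The sketch assumes single-process loading.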
The attribute of the PyTorch Geometric Data object is named edge_index_dihedral_pairs so that the dihedral pairs are treated as edge indices during batching. PyTorch Geometric then automatically offsets the indices in the dihedral pairs by the sizes of the preceding graphs when creating a batch, exactly as it does for edge_index.
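To make the batching behaviour concrete, here is a simplified pure-Python imitation of it, with no PyTorch Geometric dependency. The collate helper and the dict-based graphs are hypothetical; in PyTorch Geometric itself the offset comes from Data.__inc__, which by default returns num_nodes for any attribute whose name contains "index", and such attributes are concatenated along their last dimension (Data.__cat_dim__ returns -1). An attribute named plain dihedral_pairs would not get this offset.

```python
from typing import Dict, List


def collate(graphs: List[Dict]) -> Dict:
    """Hypothetical, simplified imitation of PyG batching: every attribute
    whose name contains "index" is shifted by the number of nodes in all
    preceding graphs, then concatenated along its last dimension."""
    batch: Dict = {}
    offset = 0
    for g in graphs:
        for key, value in g.items():
            if key == "num_nodes":
                continue
            if "index" in key:
                # value has shape [2, num_entries], like edge_index
                shifted = [[i + offset for i in row] for row in value]
                if key not in batch:
                    batch[key] = shifted
                else:
                    for dst, src in zip(batch[key], shifted):
                        dst.extend(src)  # concatenate along the last dim
            else:
                batch.setdefault(key, []).extend(value)
        offset += g["num_nodes"]
    batch["num_nodes"] = offset
    return batch


g1 = {"num_nodes": 4,
      "edge_index": [[0, 1, 2], [1, 2, 3]],
      "edge_index_dihedral_pairs": [[0], [3]]}
g2 = {"num_nodes": 5,
      "edge_index": [[0, 1, 2, 3], [1, 2, 3, 4]],
      "edge_index_dihedral_pairs": [[0, 1], [3, 4]]}
batch = collate([g1, g2])
print(batch["edge_index_dihedral_pairs"])  # [[0, 4, 5], [3, 7, 8]]
```

The dihedral pairs of the second graph are shifted by 4, the node count of the first graph, so they still point at the correct nodes in the batched graph.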