choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

_accumulate dropped in pytorch 2.3 #124

Closed by chrisiacovella 4 months ago

chrisiacovella commented 4 months ago

The random_record_split function relies on _accumulate from torch._utils to compute the offset indices for each split; PyTorch 2.3 has removed this private function.

https://github.com/choderalab/modelforge/blob/90382d84cb0c8e813cb9e768d7a6326388d36bed/modelforge/dataset/utils.py#L465

indices_by_split: List[List[int]] = []
# torch._utils._accumulate produced the running totals of the split lengths;
# this private helper no longer exists in PyTorch 2.3.
for offset, length in zip(torch._utils._accumulate(lengths), lengths):
    indices = []
    for record_idx in record_indices[offset - length : offset]:
        indices.extend(dataset.get_series_mol_idxs(record_idx))
    indices_by_split.append(indices)

We are only using this to compute the cumulative sum of the split lengths, so replacing it with np.cumsum(lengths) should behave identically. I will make this change in PR #123.
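For reference, a minimal sketch of the proposed replacement, with np.cumsum standing in for torch._utils._accumulate. The toy dataset, lengths, and record_indices below are made up for illustration only and are not modelforge's actual API:

```python
from typing import List

import numpy as np


class _ToyDataset:
    """Hypothetical stand-in for the modelforge dataset used in the excerpt above."""

    def get_series_mol_idxs(self, record_idx: int) -> List[int]:
        # Pretend each record owns two consecutive series indices (illustrative only).
        return [2 * record_idx, 2 * record_idx + 1]


dataset = _ToyDataset()
lengths = [3, 2]  # split sizes (illustrative values)
record_indices = list(range(sum(lengths)))

indices_by_split: List[List[int]] = []
# np.cumsum(lengths) yields the same running totals that torch._utils._accumulate did,
# so each split still covers record_indices[offset - length : offset].
for offset, length in zip(np.cumsum(lengths), lengths):
    indices = []
    for record_idx in record_indices[offset - length : offset]:
        indices.extend(dataset.get_series_mol_idxs(record_idx))
    indices_by_split.append(indices)

print(indices_by_split)  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]
```

The standard library's itertools.accumulate would also give the same offsets, but np.cumsum matches the change described here.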

wiederm commented 4 months ago

Closing this since it has been addressed in PR #123