Luthaf / rascaline

Computing representations for atomistic machine learning
https://luthaf.fr/rascaline/
BSD 3-Clause "New" or "Revised" License

Sample selection is extremely slow for NeighborList #319

Closed. Luthaf closed this issue 2 months ago.

Luthaf commented 3 months ago

This is related to https://github.com/lab-cosmo/metatensor/issues/700. I also tried doing the selection with selected_samples, but this slowed the neighbor list (NL) calculation down from ~3 s to ~30 s on my machine.

import rascaline
import metatensor as mts

import numpy as np
import ase.io

frame_iron = ase.io.read("iron-snapshot.xyz", 0)

# Select a slab of atoms around the middle of the cell along z,
# to speed up the calculation and reduce memory requirements
crop_dz = 16  # slab thickness along z, in Å
crop_idx = np.where(
    (frame_iron.positions[:, 2] < frame_iron.cell[2, 2] / 2 + crop_dz / 2)
    & (frame_iron.positions[:, 2] > frame_iron.cell[2, 2] / 2 - crop_dz / 2)
)[0]
frame_iron = frame_iron[crop_idx]

# Core atoms: a thinner slab along z, away from the x/y edges of the crop
core_dz = 4
max_cutoff = 6
core_idx = np.sort(
    np.where(
        (frame_iron.positions[:, 0] > max_cutoff + 1)
        & (frame_iron.positions[:, 0] < 199 - max_cutoff)
        & (frame_iron.positions[:, 1] > max_cutoff + 1)
        & (frame_iron.positions[:, 1] < 199 - max_cutoff)
        & (frame_iron.positions[:, 2] < frame_iron.cell[2, 2] / 2 + core_dz / 2)
        & (frame_iron.positions[:, 2] > frame_iron.cell[2, 2] / 2 - core_dz / 2)
    )[0]
)
frame_core = frame_iron[core_idx]

# Full neighbor list, without sample selection: this is fast-ish
nl_code = rascaline.NeighborList(cutoff=max_cutoff, full_neighbor_list=True)
nl_all = nl_code.compute(frame_iron)

# Same calculation restricted to the core atoms: this is very slow
# %%time
selected_samples = mts.Labels(names=["first_atom"], values=core_idx[:, np.newaxis])

nl_code = rascaline.NeighborList(cutoff=max_cutoff, full_neighbor_list=True)
nl_core = nl_code.compute(frame_iron, selected_samples=selected_samples)
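Since the unrestricted calculation is fast, one possible workaround (not something proposed in this issue) is to compute the full neighbor list and only afterwards keep the pairs whose first atom belongs to core_idx. The sketch below works on a plain NumPy array standing in for the neighbor list sample values; the column layout and the helper name filter_pairs_by_first_atom are hypothetical, chosen for illustration, and do not describe rascaline's actual sample layout. With full_neighbor_list=True every pair appears in both directions, so filtering on the first atom alone keeps all neighbors of each core atom.

```python
import numpy as np


def filter_pairs_by_first_atom(samples, first_atom_col, selected_atoms):
    """Return a boolean mask keeping the rows whose first atom is selected.

    Hypothetical helper: `samples` is an (n_pairs, n_columns) integer array
    and `first_atom_col` is the index of the column holding the first atom.
    """
    # Vectorized membership test over all pairs at once, no Python-level loop
    return np.isin(samples[:, first_atom_col], selected_atoms)


# Toy stand-in for the sample values, columns: (structure, first_atom, second_atom)
samples = np.array([
    [0, 0, 1],
    [0, 1, 0],
    [0, 1, 2],
    [0, 2, 1],
    [0, 3, 2],
])
core_idx = np.array([1, 3])

mask = filter_pairs_by_first_atom(samples, first_atom_col=1, selected_atoms=core_idx)
core_pairs = samples[mask]
```

In a real calculation the same mask would also have to be applied to the corresponding rows of the block values, so that samples and values stay aligned.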
Luthaf commented 3 months ago

Using rascaline's built-in profiler, I get the following timings:

Without sample selection:

╔════╦══════════════════════════════════════════════╦════════════╦═══════════╦══════════╦══════════╗
║ id ║ span name                                    ║ call count ║ called by ║ total    ║ mean     ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  1 ║ Calculator::prepare                          ║          1 ║         — ║ 942.11ms ║ 942.11ms ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  0 ║ NeighborsList                                ║          1 ║         1 ║ 342.48ms ║ 342.48ms ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  2 ║ NeighborList::compute                        ║          1 ║         — ║    1.98s ║    1.98s ║
╚════╩══════════════════════════════════════════════╩════════════╩═══════════╩══════════╩══════════╝

With sample selection:

╔════╦══════════════════════════════════════════════╦════════════╦═══════════╦══════════╦══════════╗
║ id ║ span name                                    ║ call count ║ called by ║ total    ║ mean     ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  1 ║ Calculator::prepare                          ║          1 ║         — ║   29.83s ║   29.83s ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  0 ║ NeighborsList                                ║          1 ║         1 ║ 368.64ms ║ 368.64ms ║
╠════╬══════════════════════════════════════════════╬════════════╬═══════════╬══════════╬══════════╣
║  2 ║ NeighborList::compute                        ║          1 ║         — ║ 740.14ms ║ 740.14ms ║
╚════╩══════════════════════════════════════════════╩════════════╩═══════════╩══════════╩══════════╝
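The two tables can be reconciled with the ~3 s and ~30 s figures quoted at the top of the issue. A small sketch, with the timings transcribed by hand from the tables (in milliseconds) and assuming, from the "called by" column, that the NeighborsList span (id 0) is nested inside Calculator::prepare (id 1), so that prepare and compute are the only top-level spans:

```python
# Timings transcribed from the two profiler tables, in milliseconds.
# Assumption: "neighbors" (span id 0) runs inside "prepare" (span id 1).
timings = {
    "without selection": {"prepare": 942.11, "neighbors": 342.48, "compute": 1980.0},
    "with selection": {"prepare": 29830.0, "neighbors": 368.64, "compute": 740.14},
}


def summarize(t):
    """Return (total time in s, time spent in prepare outside the neighbor search in ms)."""
    total_s = (t["prepare"] + t["compute"]) / 1000.0
    prepare_rest_ms = t["prepare"] - t["neighbors"]
    return total_s, prepare_rest_ms


for label, t in timings.items():
    total_s, rest_ms = summarize(t)
    print(f"{label}: total {total_s:.2f} s, prepare outside neighbor search {rest_ms:.2f} ms")
```

Under that reading, the totals come out to about 2.9 s and 30.6 s. The neighbor search itself takes roughly the same time in both runs (342 ms vs 369 ms) and compute actually gets faster with selection, so nearly all of the extra ~29 s is spent in Calculator::prepare outside the NeighborsList span.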