apax-hub / apax

A flexible and performant framework for training machine learning potentials.
MIT License

[ENH] Switching function to suppress NN output and go to purely scale-shift correction when no atoms nearby #291

Open Chronum94 opened 3 weeks ago

Chronum94 commented 3 weeks ago

@M-R-Schaefer and I have discussed this a small amount, and the motivation for this idea is that, given an energy that looks something like so:

E = E_NN + E_corr

is by itself likely not enough to enforce certain corrections. For example, if E_corr is isolated-atom energy shifts, an isolated atom will still have an energy that looks like so:

E = E_NN + E_isolated

which is only correct in the limit of E_NN being exactly zero when an atom has nothing near it. This need not be true in the current model: even with all moment amplitudes being 0, and without (and perhaps even with) sufficient data available, E_NN can do whatever it likes, because of its bias terms.
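As a minimal illustration of the bias-term issue (a hypothetical two-layer readout, not apax's actual architecture): a dense network with biases produces a nonzero output even when all of its input features are zero, which is exactly the isolated-atom situation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-atom readout: descriptor features -> hidden -> scalar energy.
# For an isolated atom, all descriptor features are zero.
features = np.zeros(8)

W1, b1 = rng.normal(size=(16, 8)), rng.normal(size=16)
W2, b2 = rng.normal(size=(1, 16)), rng.normal(size=1)

hidden = np.tanh(W1 @ features + b1)  # = tanh(b1), generally nonzero
e_nn = W2 @ hidden + b2               # generally nonzero as well

# The network "predicts" a nonzero energy for an empty environment.
print(e_nn)
```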

One possible way to improve this behaviour is to suppress E_NN and make the contribution purely E_corr (E_isolated is specifically what I had in mind) using some measure of "isolation" of an atom. So something that looks like so:

E = E_NN * switching_f_value + E_corr * (1 - switching_f_value)

Such that when switching_f_value goes to 0.0, by construction E can only be E_corr.
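As a sketch in plain Python (the function name blended_energy is illustrative, not an apax API), the convex combination enforces the limit by construction:

```python
def blended_energy(e_nn, e_corr, s):
    # s = 1: purely the NN energy; s = 0: purely the correction.
    return e_nn * s + e_corr * (1.0 - s)

# In the isolated-atom limit the switching value goes to 0.0,
# so the result can only be e_corr, regardless of what e_nn outputs.
print(blended_energy(e_nn=0.37, e_corr=-5.0, s=0.0))  # -> -5.0
```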

One possible switching function (and its input) is switch(dr) = 1.0 - cosine_cutoff(dr_min), where dr_min is the minimum distance to a neighbour for an atom. This can be computed with jax.ops.segment_min.

However, this function has a cusp whenever two atoms swap roles as the nearest neighbour.

To remedy this, one can imagine a 'smooth min' function which smoothly interpolates between the two distances whenever two atoms are comparably close. One possible way to do this is by hacking a softmax function like so:

import numpy as np
import matplotlib.pyplot as plt

x1 = np.linspace(1, 10, 1000)
x2 = np.linspace(10, 1, 1000)

# Strict elementwise minimum of the two "distances" (has a cusp at the crossing).
minx = np.minimum(x1, x2)
plt.plot(x1, x1, 'k', x1, x2, 'k')

# Softmax-weighted minimum for a range of characteristic scales ces.
for ces in np.linspace(0.1, 5, 10):
    w1 = np.exp(-(x1 - minx) / ces)
    w2 = np.exp(-(x2 - minx) / ces)
    plt.plot(x1, (x1 * w1 + x2 * w2) / (w1 + w2))

which gives an output that looks like so:

(figure: soft-min curves interpolating between the two black lines, one per value of ces)

This function (or this hacky way of writing it) switches smoothly between whichever atom is nearest. In the limit of atoms being infinitely separated, it weights the nearest one exclusively. As long as ces (characteristic <I've forgotten> scale) is a small number, for almost all values of radial distance the distance of the nearest atom is a sufficiently good approximation to the minimum. A good value here would likely be something like 0.1-1.0 angstrom, but chances are that final results will be insensitive to this.

Now our final expression would look something like:

E = E_NN * cosine_cutoff(smooth_min(dr)) +  E_corr * (1.0 - cosine_cutoff(smooth_min(dr)))

Any switching function works, of course, even a learnable one.
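Putting the pieces together, a hedged sketch of the full per-atom switched energy in JAX (the names cosine_cutoff, smooth_min, and switched_energy are illustrative, not apax's API; assumes every atom has at least one neighbour within the cutoff):

```python
import jax.numpy as jnp
from jax.ops import segment_min, segment_sum

def cosine_cutoff(x, r_cut):
    # Standard cosine cutoff: 1 at x = 0, smoothly going to 0 at x >= r_cut.
    return jnp.where(x < r_cut, 0.5 * (jnp.cos(jnp.pi * x / r_cut) + 1.0), 0.0)

def smooth_min(dr, idx_i, n_atoms, ces=0.5):
    # Softmax-weighted per-atom neighbour distance, bounded below by the strict min.
    dr_min_per_pair = jnp.repeat(
        segment_min(dr, idx_i, n_atoms), jnp.bincount(idx_i, length=n_atoms)
    )
    w = jnp.exp(-(dr - dr_min_per_pair) / ces)
    return segment_sum(dr * w, idx_i, n_atoms) / segment_sum(w, idx_i, n_atoms)

def switched_energy(e_nn, e_corr, dr, idx_i, n_atoms, r_cut):
    # E = E_NN * s + E_corr * (1 - s), with s built from the smooth min.
    s = cosine_cutoff(smooth_min(dr, idx_i, n_atoms), r_cut)
    return e_nn * s + e_corr * (1.0 - s)

# Toy data: 2 atoms, each seeing the other at distance 0.8 within r_cut = 2.0.
dr = jnp.array([0.8, 0.8])
idx_i = jnp.array([0, 1])
e_nn = jnp.array([0.3, 0.3])
e_corr = jnp.array([-5.0, -5.0])
e = switched_energy(e_nn, e_corr, dr, idx_i, n_atoms=2, r_cut=2.0)
```

For close atoms the result sits between e_corr and e_nn; once the nearest neighbour moves beyond r_cut, the cutoff reaches zero and the energy collapses to e_corr exactly.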

The above example is of course for a 2-neighbours-per-atom setup for convenience, and here I'm presenting the relevant lines of code to hopefully implement this somewhere in apax, if this is chosen to be implemented.

from jax.ops import segment_sum, segment_min
import jax.numpy as jnp
import numpy as np

from matscipy.neighbours import neighbour_list

np.random.seed(35280)
num_points = 6
a = jnp.array(np.random.rand(num_points, 3))

# Pair indices i, j and pair distances dr within the cutoff.
i, j, dr = neighbour_list('ijd', cutoff=1.0, positions=np.array(a))

# Strict per-atom minimum distance, broadcast back to one entry per pair.
dr_mins = segment_min(dr, i, num_points)
repeats = jnp.bincount(i)
dr_mins = jnp.repeat(dr_mins, repeats)

# Softmax weights relative to each atom's minimum distance, then the weighted mean.
softmaxed_dr = jnp.exp(-(dr - dr_mins) / 0.05)
softmax_weighted_dr = dr * softmaxed_dr / jnp.repeat(segment_sum(softmaxed_dr, i, num_points), repeats)

print(segment_min(dr, i, num_points))                   # strict min
print(segment_sum(softmax_weighted_dr, i, num_points))  # smooth min

Which should give this output:

[0.12716725 0.27708408 0.12716725 0.45081526 0.41279438 0.27708408]
[0.12740156 0.29144165 0.12797245 0.46517906 0.43096972 0.28029796]

The smooth min is thus bounded from below by the strict (segment) min, while remaining smooth. This seems like a good compromise.
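The lower bound is easy to check directly: the soft min is a convex combination of the distances (the weights are positive and normalised), so it can never undercut the true minimum. A standalone NumPy check for a single atom's neighbour list:

```python
import numpy as np

rng = np.random.default_rng(0)
dr = rng.uniform(0.5, 3.0, size=200)  # hypothetical neighbour distances for one atom

for ces in [0.05, 0.5, 5.0]:
    # Boltzmann-style weights: the smallest distance gets weight 1.
    w = np.exp(-(dr - dr.min()) / ces)
    soft_min = np.sum(dr * w) / np.sum(w)
    # A convex combination of the dr values stays within [min, max].
    assert dr.min() <= soft_min <= dr.max()
```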

I'd love to hear thoughts on this and any questions about the motivations of this that I can clarify.

Chronum94 commented 3 weeks ago

Some more tests in 3D, since my previous comment on 0.1-1.0 is actually wrong; the right value is likely that, scaled by the number of atoms (and it can in principle be learned or kept as a user-input parameter). Some code:

from jax.ops import segment_sum, segment_min
import jax.numpy as jnp
import numpy as np

from matscipy.neighbours import neighbour_list
from scipy.stats.qmc import PoissonDisk

import matplotlib.pyplot as plt

np.random.seed(35280)

# Poisson-disk samples in 3D as a stand-in for atomic positions.
sampler = PoissonDisk(3)
points = sampler.fill_space()
num_points = len(points)
a = jnp.array(points)
i, j, dr = neighbour_list('ijd', cutoff=0.2, positions=np.array(a))

# Strict per-atom minimum distance, broadcast back to one entry per pair.
dr_mins = segment_min(dr, i, num_points)
repeats = jnp.bincount(i)
dr_mins = jnp.repeat(dr_mins, repeats)

# Compare strict min against soft min for several characteristic scales.
for exp_denom in np.logspace(-4, -2, 3):
    softmaxed_dr = jnp.exp(-(dr - dr_mins) / exp_denom)
    softmax_weighted_dr = dr * softmaxed_dr / jnp.repeat(segment_sum(softmaxed_dr, i, num_points), repeats)
    strict_min_dr = segment_min(dr, i, num_points)
    soft_min_dr = segment_sum(softmax_weighted_dr, i, num_points)

    plt.scatter(strict_min_dr, soft_min_dr, s=1, label=f"{exp_denom:0.1e}")
plt.legend()
plt.show()

(figure: scatter of soft min vs. strict min for the three values of exp_denom)