**Chronum94** opened this issue 3 weeks ago
Some more tests in 3D, since my previous comment suggesting 0.1-1.0 is actually wrong: the appropriate value is likely that range scaled by the number of atoms (and can in principle be learned, or kept as a user-input parameter). Some code:
```python
from jax.ops import segment_sum, segment_min
import jax.numpy as jnp
import numpy as np
from matscipy.neighbours import neighbour_list
from scipy.stats.qmc import PoissonDisk
import matplotlib.pyplot as plt

np.random.seed(35280)
sampler = PoissonDisk(3)
points = sampler.fill_space()
num_points = len(points)
a = jnp.array(points)

# Pairwise neighbour list: i are the central atoms, dr the pair distances.
i, j, dr = neighbour_list('ijd', cutoff=0.2, positions=np.array(a))

# Strict per-atom minimum distance, broadcast back onto the pair list.
dr_mins = segment_min(dr, i, num_points)
repeats = jnp.bincount(i, length=num_points)
dr_mins = jnp.repeat(dr_mins, repeats)

for exp_denom in np.logspace(-4, -2, 3):
    # Softmax weights peaked at the minimum distance; subtracting dr_mins
    # keeps the exponentials numerically well-behaved.
    softmaxed_dr = jnp.exp(-(dr - dr_mins) / exp_denom)
    softmax_weighted_dr = dr * softmaxed_dr / jnp.repeat(
        segment_sum(softmaxed_dr, i, num_points), repeats)
    strict_min_dr = segment_min(dr, i, num_points)
    soft_min_dr = segment_sum(softmax_weighted_dr, i, num_points)
    plt.scatter(strict_min_dr, soft_min_dr, s=1, label=f"{exp_denom:0.1e}")
plt.legend()
plt.show()
```
@M-R-Schaefer and I have discussed this a small amount, and the motivation for this idea is that, given an energy that looks something like

`E = E_NN + E_corr`

this decomposition is by itself likely not enough to enforce certain corrections. For example, if `E_corr` is an isolated-atom energy shift, an isolated atom will still have an energy that looks like

`E_isolated_atom = E_NN + E_corr`

which is only correct in the limit of `E_NN` being exactly zero when an atom has nothing near it. This need not be true in the current model: one can imagine that even with all moment amplitudes being 0, without (and perhaps with) sufficient data available, `E_NN` can do whatever it likes (because of bias terms).
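As a toy illustration of the bias-term point (a hypothetical two-layer readout, not apax's actual network), even an all-zero feature vector produces a nonzero energy:

```python
import jax.numpy as jnp

def mlp_energy(features, w1, b1, w2, b2):
    # Toy two-layer readout: with zero input features, the bias terms
    # still propagate through and yield a nonzero output energy.
    h = jnp.tanh(features @ w1 + b1)
    return h @ w2 + b2

# All moment amplitudes zero, as for a fully isolated atom.
features = jnp.zeros(4)
w1 = jnp.ones((4, 8)) * 0.1
b1 = jnp.ones(8) * 0.5
w2 = jnp.ones(8) * 0.1
b2 = jnp.array(0.3)

print(mlp_energy(features, w1, b1, w2, b2))  # nonzero despite zero input
```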
One possible way to improve this behaviour is to suppress `E_NN` and make the contribution purely `E_corr` (`E_isolated` is specifically what I had in mind), using some measure of "isolation" of an atom. So something that looks like

`E = switching_f_value * E_NN + E_corr`

such that when `switching_f_value` goes to 0.0, by construction `E` can only be `E_corr`. One possible switching function (and its input) can be `switch(dr): 1.0 - cosine_cutoff(min(dr))`, where `dr_min` is the minimum distance to a neighbour for an atom. This can be done with `jax.ops.segment_min`. However, this function has a cusp whenever two atoms swap which of them is at the minimum distance.
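A minimal sketch of such a switching function, using the common cosine cutoff form (the names `cosine_cutoff`, `switching_f`, and `r_cut` are illustrative, not apax's API):

```python
import jax.numpy as jnp
from jax.ops import segment_min

def cosine_cutoff(r, r_cut):
    # Standard cosine cutoff: 1 at r = 0, smoothly decaying to 0 at r_cut.
    return jnp.where(r < r_cut, 0.5 * (jnp.cos(jnp.pi * r / r_cut) + 1.0), 0.0)

def switching_f(dr, i, num_atoms, r_cut):
    # Per-atom minimum neighbour distance via segment_min, then
    # 1 - cosine_cutoff so the value approaches 1 as the atom becomes
    # isolated. Atoms with no neighbours get segment_min's identity (+inf),
    # so the cutoff is 0 and the switch is 1, as desired.
    dr_min = segment_min(dr, i, num_atoms)
    return 1.0 - cosine_cutoff(dr_min, r_cut)
```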
To remedy this, one can imagine a 'smooth min' function which switches smoothly between the distances of two atoms whenever both are close by. One possible way to do this is by hacking a softmax function like so:
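A minimal per-atom sketch of that softmax-based soft min (for an array of neighbour distances `dr`; `ces` is the characteristic scale discussed below):

```python
import jax.numpy as jnp

def soft_min(dr, ces):
    # Softmax weights peaked on the smallest distance; subtracting the strict
    # min first keeps the exponentials numerically well-behaved. The result is
    # a convex combination of the distances, dominated by the smallest one.
    w = jnp.exp(-(dr - jnp.min(dr)) / ces)
    return jnp.sum(dr * w) / jnp.sum(w)

dr = jnp.array([1.0, 1.5])
print(soft_min(dr, 0.1))  # close to 1.0, the strict minimum
```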
which gives an output that looks like so:

![image](https://github.com/apax-hub/apax/assets/6704979/4caf75a1-061d-4be6-ae99-b2a6413c1df1)
This function (or this hacky way of writing it) switches smoothly between one atom being the closest and another atom being the closest. In the limit of atoms being infinitely separated, it weights the nearest one exclusively. As long as `ces` (characteristic <I've forgotten> scale) is a small number, the distance to the nearest atom is, for almost all values of radial distance, a sufficiently good approximation to the minimum. A good value here would likely be something like 0.1-1.0 angstrom, but chances are that final results will be insensitive to this.

Now our final expression would look something like

`E = switching_f(soft_min_dr) * E_NN + E_corr`
Any switching function works, of course, even a learnable one.
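Putting the pieces together, a per-atom sketch of what the final switched energy could look like (the function name, signature, and the guard for neighbourless atoms are all hypothetical, not apax's internals):

```python
import jax.numpy as jnp
from jax.ops import segment_min, segment_sum

def switched_energy(e_nn, e_corr, dr, i, num_atoms, ces, r_cut):
    # Soft minimum neighbour distance per atom (softmax-weighted mean of dr).
    dr_min = segment_min(dr, i, num_atoms)
    w = jnp.exp(-(dr - dr_min[i]) / ces)
    denom = jnp.maximum(segment_sum(w, i, num_atoms), 1e-12)
    soft_min = segment_sum(dr * w, i, num_atoms) / denom
    # Atoms with no neighbours at all count as fully isolated.
    has_neighbours = segment_sum(jnp.ones_like(dr), i, num_atoms) > 0
    soft_min = jnp.where(has_neighbours, soft_min, jnp.inf)
    # Cosine cutoff on the soft min: near 1 for a well-coordinated atom,
    # 0 once the (soft) nearest neighbour is beyond r_cut, so the energy
    # reduces to e_corr for an isolated atom by construction.
    cutoff = jnp.where(soft_min < r_cut,
                       0.5 * (jnp.cos(jnp.pi * soft_min / r_cut) + 1.0), 0.0)
    return cutoff * e_nn + e_corr
```

For a dimer at distance 0.5 with `r_cut = 1.0` the NN term is half suppressed, while a third atom with no neighbours gets exactly its correction energy.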
The above example is of course for a 2-neighbours-per-atom setup for convenience, and here I'm presenting the relevant lines of code in the hope of implementing this somewhere in apax, if this is chosen to be implemented. Which should give this output:
The smooth min should thus be bounded from below by the strict `segment_min`, but it is also smooth. This seems like a good compromise.
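That lower bound is easy to check numerically: the soft min is a convex combination of the distances, so it can never dip below the strict minimum. A quick standalone sanity check using the same softmax trick:

```python
import numpy as np

rng = np.random.default_rng(0)
for _ in range(100):
    dr = rng.uniform(0.1, 2.0, size=8)
    w = np.exp(-(dr - dr.min()) / 0.1)
    soft_min = np.sum(dr * w) / np.sum(w)
    # Convex combination of the dr values, hence always >= the strict min.
    assert soft_min >= dr.min()
```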
I'd love to hear thoughts on this and any questions about the motivations of this that I can clarify.