proteneer / timemachine

Differentiate all the things!

Detecting and handling numerical overflows #481

Closed maxentile closed 1 year ago

maxentile commented 3 years ago

Currently, the potential energy function cannot return a nan or inf when numerical overflow is encountered, instead returning values around -1e+7.

This causes problems in two settings:

The current workaround in the Python layer is to compare abs(U) to an absolute guard_threshold of 1e+6. Justification for this threshold: For the smallest systems (such as a single LJ pair) as well as practical systems, overflows appear as values around -1e+7, while valid energies are at least 2 orders of magnitude away from this threshold. However, if we considered larger systems, then valid typical energies could get even closer to this threshold.
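For illustration, the thresholding described above could be sketched as follows. This is a hypothetical C++ rendering (the workaround itself lives in the Python layer); the name `sanitize_energy` and the choice to map suspect values to NaN are assumptions made for this sketch, not the project's code.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper: treat any energy whose magnitude exceeds the guard
// threshold (or that is already non-finite) as a suspected overflow and
// report it as NaN so that callers can detect it.
double sanitize_energy(double u, double guard_threshold = 1e6) {
    assert(guard_threshold > 0.0);
    if (!std::isfinite(u) || std::abs(u) > guard_threshold) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    return u;
}
```

With the default threshold, an overflow signature such as -1e+7 becomes NaN, while a typical valid energy passes through unchanged. The weakness noted above remains: a sufficiently large system could produce valid energies that exceed any fixed absolute threshold.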

After some discussion surrounding https://github.com/proteneer/timemachine/pull/476#discussion_r682934819 and sanitize_energies, we may want to change this behavior.

Some considerations:

Some possible changes:

proteneer commented 3 years ago

+1 to this and thanks for the concise summary.

I'd really like to just directly address the overflow by preserving the sign bit somehow as opposed to masking things out. I think we need to overhaul the energy system a bit to be more precise when computing delta_Us (which are extensive w.r.t. number of ligand atoms as opposed to total # of atoms), and to be safer against overflows. I will do some tinkering in python land...

proteneer commented 3 years ago

Okay - I've identified the problem on the C++ layer, but I don't have a solution. Let's assume that we're not dealing with exclusions but the calculation still blows up when we evaluate energies at lambda=0 using conformations from lambda=1.0 during decoupling. This will generate states where the atoms will be colliding on top of each other.

We accumulate energies (and forces etc.) in fixed point by converting from float -> int64 -> uint64. For exclusions, this works without issues because even though overflow in the float -> int64 step may occur, as long as it's deterministic we will back-subtract and cancel out this term entirely. However, in the event that we have a true clash, the float -> int64 conversion will lose the sign bit information entirely. In fact, many implementations will saturate the bit pattern to all 1s. So a very large positive energy may become negative. Furthermore, even if we did somehow set the sign bit to be consistent, we would still have the problem that the accumulation of int64s during an overflow will "wrap around" and flip the sign bit unintentionally.

#define FIXED_EXPONENT             0x1000000000

#include <iostream>
#include <cmath>

long long real_to_int64(double x) {
    return std::llrint(x);
}

template<typename RealType>
unsigned long long FLOAT_TO_FIXED(RealType v) {
    return static_cast<unsigned long long>(real_to_int64(v*FIXED_EXPONENT));
}

template<typename RealType>
RealType FIXED_TO_FLOAT(unsigned long long v) {
    return static_cast<RealType>(static_cast<long long>(v))/FIXED_EXPONENT;
}

void display(float x) {
    std::cout << "x: " << x << " llrint(x*FIXED_EXPONENT): " << std::llrint(x*FIXED_EXPONENT) << std::endl;
}

int main() {

    for (long i=0x1000000; i <= 0x10000000; i*=2) {
        display(i);
    }

    long long x = 4611686018427387904;

    // wrong sign in signed int64s
    std::cout << "sum " << x + x + x << std::endl;

    // wrong sign even if we convert to unsigned int64s during the sum.
    std::cout << "sum " << static_cast<long long>(static_cast<unsigned long long>(x) + static_cast<unsigned long long>(x) + static_cast<unsigned long long>(x)) << std::endl;

}

Will result in:

x: 1.67772e+07 llrint(x*FIXED_EXPONENT): 1152921504606846976
x: 3.35544e+07 llrint(x*FIXED_EXPONENT): 2305843009213693952
x: 6.71089e+07 llrint(x*FIXED_EXPONENT): 4611686018427387904
x: 1.34218e+08 llrint(x*FIXED_EXPONENT): -9223372036854775808
x: 2.68435e+08 llrint(x*FIXED_EXPONENT): -9223372036854775808
sum -4611686018427387904
sum -4611686018427387904

proteneer commented 3 years ago

PS OpenMM doesn't have this problem, as its energies are accumulated in floating point. However this has other problems (eg. the barostat and MC moves that depend on energy would no longer be deterministic). To be clear, this is a problem currently limited to MBAR, and not BAR, since BAR would typically be used to evaluate energies at adjacent windows.

proteneer commented 2 years ago

After some discussion with @maxentile, it looks like we finally have a fairly clean way of addressing this, at least in the following two use cases:

1) Using samples generated where the potential energy sets lambda=1 (non-interacting), and evaluating each sample with clashes by using a potential energy with lambda=0 (interacting).
2) Re-parameterizing a ligand parameter, e.g. sigma in the Lennard-Jones potential, that would result in a clash with either the ligand or the host.

We can solve the above issues by noting that the nonbonded kernel can be separated into three parts:

U_T = U_H + U_HX + U_X

where U_H denotes the host-host interactions, U_HX denotes the host-ligand interactions, U_X denotes the ligand-ligand interactions. Exclusions are only present for the U_H and U_X case, and never the U_HX case.

For U_HX, the number of tiles involved scales only with respect to the size of the ligand and the cutoff. Practically, this means that we can deterministically compute a per-tile energy, and deterministically accumulate using a parallel reduction. Note that this does not help with the force computation, since a per-tile force may be a prohibitively large buffer. @maxentile, is there a reasonable use-case for getting the sign bit of the forces correct?

For U_X, we can process this as a pairlist, since the number of interactions will be quite small, and we can remove the exclusions associated with these interactions explicitly. Again, we can proceed via a parallel reduction.

For U_H, we use the old fixed-point accumulator, but since U_H is always well behaved for scenarios 1) and 2), we do not need to worry about them.
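A minimal sketch of the "deterministically accumulate" step, assuming the per-tile energies are held as doubles (the thread does not say which representation is used) and written serially for clarity. The function name `reduce_tile_energies` is hypothetical. The point is that a fixed pairwise summation order gives the same result on every run, and a floating-point sum of large positive terms goes to +inf instead of wrapping to a negative value.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical fixed-order pairwise (tree) reduction over per-tile energies.
// The pairing depends only on the number of tiles, so the result is
// reproducible regardless of how the tiles were computed.
double reduce_tile_energies(std::vector<double> tile_energies) {
    if (tile_energies.empty()) return 0.0;
    for (std::size_t stride = 1; stride < tile_energies.size(); stride *= 2) {
        // Each pass folds element i + stride into element i.
        for (std::size_t i = 0; i + stride < tile_energies.size(); i += 2 * stride) {
            tile_energies[i] += tile_energies[i + stride];
        }
    }
    return tile_energies[0];
}
```

On a GPU each inner pass would run in parallel, but the pairing (and therefore the rounding) would be the same as in this serial version.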

maxentile commented 2 years ago

Note that this does not help with the force computation, since a per-tile force may be a prohibitively large buffer, @maxentile is there a reasonable use-case for getting the sign bit of the forces correct?

I can't think of one! If we can detect that the energy is nan or +inf, we wouldn't use the associated forces for anything.

proteneer commented 2 years ago

To add more commentary to this after debugging with @mcwitt. There are three sources of overflow/UB in our fixed point implementation:

1) casting from a float/double to signed long long
2) casting from an unsigned long long to signed long long
3) summing multiple signed long long together

So:

1) we can detect via a guard, e.g. |x| > 2^63, and check the sign manually, so when we overflow we saturate the sign bit correctly. By default, when overflowing (either positive or negative), gcc seems to saturate every bit.
2) while this is technically UB, it seems that we can at least guarantee static_cast<long long>(static_cast<unsigned long long>(x)) = x for g++/cuda?
3) don't have a solution for this still
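The guard in item 1 might look roughly like the sketch below. The function name `saturating_real_to_int64` is hypothetical, and rejecting NaN inputs with an assertion is an assumption of this sketch. Since 2^63 is exactly representable as a double, the comparisons are exact, and `std::llrint` is only ever called on values that fit in a signed 64-bit integer.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical saturating replacement for a raw llrint-based conversion:
// out-of-range inputs clamp to the extreme value with the correct sign.
long long saturating_real_to_int64(double x) {
    assert(!std::isnan(x)); // NaN has no meaningful fixed-point value
    const double two_63 = 9223372036854775808.0; // 2^63, exact in double
    if (x >= two_63) return std::numeric_limits<long long>::max();
    if (x <= -two_63) return std::numeric_limits<long long>::min();
    return std::llrint(x);
}
```

On item 2, a side note: converting an out-of-range unsigned value to a signed type is implementation-defined before C++20 and defined as modular from C++20 onward, whereas the signed addition overflow in item 3 is undefined behavior in every C++ standard.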

maxentile commented 2 years ago

Tacking a low-priority task onto this issue: We also want to add documentation of the supported dynamic ranges for parameter derivatives, probably near these constants: https://github.com/proteneer/timemachine/blob/5b34f421dacf7eaf2d7c1a14e4808c1760f0bc69/timemachine/cpp/src/fixed_point.hpp#L5-L8 (Migrated from https://github.com/proteneer/timemachine/pull/556#discussion_r782641615 .)

maxentile commented 1 year ago

A minor consequence of this issue is that the error message when a simulation blows up is slightly cryptic (because this error detection is dependent on the barostat).

--> Perhaps we can detect simulation instability by checking some other condition -- perhaps force norm exceeding this constant https://github.com/proteneer/timemachine/blob/6a1d8e21c4833565b06cee1489b47992bb45bc0e/timemachine/constants.py#L23-L24

maxentile commented 1 year ago

Just documenting that this issue still precludes the use of basic MC moves in timemachine.

Illustration with random walk Metropolis: https://gist.github.com/maxentile/56661d087919d38b0693a44794f3ccab

(Setup: Collect approximate equilibrium samples for a small system (random FreeSolv molecule in a waterbox), and estimate the acceptance fraction of random walk Metropolis moves starting from these samples, as a function of the random walk proposal standard deviation. x_prop = x + proposal_stddev * np.random.randn(*x.shape); accept_prob = min(1, exp(-u(x_prop) - (-u(x)))), average over many realizations where x is drawn from equilibrium.)

Scanning over proposal_stddevs logarithmically spaced between 1e-7 and 1e+1 nm, we observe a plot something like this:

(image: apparent acceptance rate as a function of proposal_stddev)

where the apparent acceptance rate drops below 1e-100 for proposals of size ~0.001, but spuriously reverts up to 50% when the proposals are larger and clashier.

This is partially mitigated by using a guard threshold of 1e+6 kJ/mol as in https://github.com/proteneer/timemachine/issues/481#issue-964051777, but this still allows a small fraction (~0.5%) of clash-inducing moves to be accepted:

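(image: apparent acceptance rate as a function of proposal_stddev, with the 1e+6 kJ/mol guard threshold applied)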

(If iterated, accepting even a single clash-inducing move would be expected to corrupt all subsequent iterations.)
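To make the failure mode concrete, here is a small sketch of the acceptance rule from the setup above, assuming reduced (unitless) energies. The function name `metropolis_accept_prob` is hypothetical. With a correct energy function, a clashing proposal has u_proposed = +inf (or NaN) and is rejected; with the wrapped fixed-point value of roughly -1e+7, the same proposal looks extremely favorable and is accepted.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Metropolis acceptance probability min(1, exp(-(u_proposed - u_current)))
// for reduced energies. A NaN proposed energy is rejected outright, and
// +inf is rejected naturally because exp(-inf) == 0.
double metropolis_accept_prob(double u_current, double u_proposed) {
    assert(std::isfinite(u_current)); // the chain must start from a valid state
    if (std::isnan(u_proposed)) return 0.0;
    return std::min(1.0, std::exp(-(u_proposed - u_current)));
}
```

For example, `metropolis_accept_prob(-100.0, INFINITY)` is 0, while `metropolis_accept_prob(-100.0, -1e7)` is 1: this is the spurious acceptance that appears for large, clashy proposals.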

maxentile commented 1 year ago

In upcoming spring cleaning week I'll add an option to compute energies in a way that's slower, but safe w.r.t. integer overflow in the sums alluded to in case 3 of https://github.com/proteneer/timemachine/issues/481#issuecomment-1010549585 (when accumulator += pair_interaction would overflow, and accumulator > 0, pair_interaction > 0).

Note there's also a 4th case that may be important to handle: when an individual pair_interaction has already overflowed (i.e. when an individual pair is clashy, U(i, j) >> 0, but U_fixed_point(i, j) < 0).

(The new option will be disabled by default (check_overflow = false) to avoid introducing performance regressions in applications that do not need to check for overflow (such as production MD).)
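As a rough sketch of what the checked sum for case 3 could look like (not the eventual implementation; the name `checked_accumulate` is hypothetical), the test can be done before the addition so that no signed overflow ever occurs:

```cpp
#include <cassert>
#include <limits>

// Hypothetical overflow-aware accumulation step. Returns false and leaves
// `accumulator` untouched if the signed 64-bit sum would not be representable.
bool checked_accumulate(long long &accumulator, long long pair_interaction) {
    const long long max = std::numeric_limits<long long>::max();
    const long long min = std::numeric_limits<long long>::min();
    // Positive overflow: accumulator + pair_interaction > max.
    if (pair_interaction > 0 && accumulator > max - pair_interaction) return false;
    // Negative overflow: accumulator + pair_interaction < min.
    if (pair_interaction < 0 && accumulator < min - pair_interaction) return false;
    accumulator += pair_interaction;
    return true;
}
```

Note that this only covers case 3. It cannot recover from the 4th case, where pair_interaction itself already carries the wrong sign before it reaches the accumulator.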

maxentile commented 1 year ago

In upcoming spring cleaning week I'll add an option to compute energies in a way that's slower, but safe w.r.t. integer overflow

Wasn't able to get to this when planned, but a code path that allows computing energies in a way that is safe w.r.t. integer overflow would still be useful in the future, to be able to perform Monte Carlo simulations within timemachine.

Two other observations:

Since exclusions will cause expected overflows, we will need to be able to detect whether we are currently processing an exclusion.

("Cancellation of NaNs" (introduced in https://github.com/proteneer/timemachine/pull/273 ): the nonbonded computation considers all possibly-within-cutoff pairs (without attempting to detect whether the pair is an exclusion), and then computes the contributions from pairs in the exclusion list a second time, and subtracts these contributions.)

This suggests defining a small structure for each i that allows very quickly testing whether its interaction with j is an exclusion, roughly exclusion_bitsets[i].get(j - i) (where each bitset may be implemented using just one or a few ULLs).
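One possible shape for such a structure, assuming that all exclusions of atom i fall within a window of 64 atom indices starting at i (an assumption of this sketch; wider windows would need more words). The type name `ExclusionBitset` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical per-atom exclusion window backed by a single 64-bit word:
// bit k is set if atom i excludes atom i + k, for offsets 0 <= k < 64.
struct ExclusionBitset {
    std::uint64_t bits = 0;

    void set(unsigned offset) {
        assert(offset < 64);
        bits |= (std::uint64_t{1} << offset);
    }

    bool get(unsigned offset) const {
        // Offsets beyond the window are treated as "not excluded".
        return offset < 64 && ((bits >> offset) & 1u);
    }
};
```

A lookup is then a shift and a mask, e.g. `exclusion_bitsets[i].get(j - i)` for j >= i, which is cheap enough to perform inside the nonbonded inner loop before deciding whether an overflow is expected.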

maxentile commented 1 year ago

This has been resolved by a heroic effort from @badisa and @proteneer in https://github.com/proteneer/timemachine/pull/1090 .

These changes should greatly simplify the implementation and analysis of MC moves, replica exchange, importance weights, and other applications that may need to compute U(x) on unfavorable proposed configurations x.

Any remaining points of numerical-overflow-related risk in timemachine seem negligible in my opinion:

(1) Possible corner cases involving -inf energy components (which we should rule out at the model definition level!)
(2) Possible precision loss or undetected overflows in force accumulation (which seems unlikely to ever be a limiting factor in practice, since we don't have any examples of this causing problems, the force accumulation closely matches OpenMM's widely-relied-upon implementation, exact samplers can be implemented even with approximate forces, ...).