proteneer opened this issue 3 years ago
One hiccup is that for non-equilibrium simulations this will be ungodly painful, and it's one of those times where I wish we had the custom force support of OpenMM.
Big +1 on this!
For a given way to expose controllable dials within the potential energy function (e.g. distance offsets, parameter scales, etc.), there are still many ways to vary these dials as a function of lambda.
I like to think of a "protocol" as a vector-valued function of a single controllable scalar -- something like `protocol(lam: float) -> control_params: array` -- to be used in conjunction with a potential energy function `u_controllable(x, control_params: array) -> float`. We can apply any specific choice of the protocol function to yield a new function `u(x, lam) = u_controllable(x, protocol(lam))` suitable for free energy calculations.

In the case of a parametric family of optimizable protocols, we would allow this vector-valued function to depend also on some `protocol_params` -- something like `protocol(lam: float, protocol_params: array) -> array`.
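A minimal sketch of this composition (the names `protocol`, `u_controllable`, and the toy control parameters below are purely illustrative, not an existing API):

```python
import jax.numpy as jnp

def u_controllable(x, control_params):
    # Toy stand-in for the real potential: control_params = [q_scale, w_offset]
    # modulate a single pairwise term between particles x[0] and x[1].
    q_scale, w_offset = control_params
    r = jnp.linalg.norm(x[0] - x[1])
    r_eff = jnp.sqrt(r ** 2 + w_offset ** 2)
    return q_scale / r_eff

def protocol(lam, protocol_params=None):
    # Vector-valued function of a single scalar lambda.
    # With protocol_params, this becomes one member of an optimizable family.
    if protocol_params is None:
        protocol_params = jnp.array([1.0, 1.0])
    a, b = protocol_params
    q_scale = 1.0 - lam ** a   # scale charges off as lambda -> 1
    w_offset = b * lam         # lift the interaction into the 4th dimension
    return jnp.array([q_scale, w_offset])

def u(x, lam):
    # u(x, lam) = u_controllable(x, protocol(lam)), suitable for free energy calculations
    return u_controllable(x, protocol(lam))
```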
A practical complication avoided by this "vector-valued function of lambda" picture is the difficulty of jointly tuning "lambda scheduling" and "lambda spacing." The assumption would be that you're always going to drag the scalar lambda from 0 to 1 at a constant speed (`lambda_schedule = np.linspace(0, 1, n_steps)`), but as needed you could make some of the `control_params` vary faster or slower as a function of lambda at different points.
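For example (a sketch only, with made-up functional forms): lambda itself is stepped uniformly, but the protocol turns charges off over the first half of the path and the LJ well depths over the second half:

```python
import numpy as np
import jax.numpy as jnp

n_steps = 1000
lambda_schedule = np.linspace(0, 1, n_steps)  # lambda always moves at constant speed

def protocol(lam):
    # control_params = [q_scale, eps_scale]
    q_scale = jnp.clip(1.0 - 2.0 * lam, 0.0, 1.0)    # charges scaled off over lam in [0, 0.5]
    eps_scale = jnp.clip(2.0 - 2.0 * lam, 0.0, 1.0)  # LJ eps scaled off over lam in [0.5, 1]
    return jnp.array([q_scale, eps_scale])
```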
In the interest of keeping my sanity, I'd like to make the nonbonded lambda scheduling a bit more flexible than the rather rigid structure we have right now. I'm really getting tired of managing multi-stage setups and dealing with their derivatives. Ping me offline for more detail re: the jank involved in parameter interpolation and its derivatives.
In particular, for the Nonbonded potential I'd like to be able to rescale `eps`, `sig`, `q`, and `w` independently of each other.

Suppose we have some arbitrary, continuous function `f(λ): R^1 -> R^1` whose domain is `[0, 1]`, subject to the boundary constraints `f(0) = 0` and `f(1) = 1`. Generally, we'd like to implement `f(λ)` in python for full expressiveness, and its derivative can be evaluated trivially using jax. Note that `f(λ)` and its derivative need to be computed only once, are flat vectors, and can be cached.

For parameter interpolation, we have:

p = (1 - f(λ)) p_src + f(λ) p_dst
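A rough illustration of the idea, assuming a hypothetical `f` written in python and differentiated with jax (these names don't correspond to any existing API):

```python
import jax
import jax.numpy as jnp

def f(lam):
    # Arbitrary continuous schedule with f(0) = 0 and f(1) = 1,
    # e.g. a smoothstep instead of the identity.
    return 3 * lam ** 2 - 2 * lam ** 3

df_dlam = jax.grad(f)

def interpolate(lam, p_src, p_dst):
    # p = (1 - f(lam)) * p_src + f(lam) * p_dst
    return (1 - f(lam)) * p_src + f(lam) * p_dst

def dp_dlam(lam, p_src, p_dst):
    # dp/dlam = (p_dst - p_src) * df/dlam, by the chain rule
    return (p_dst - p_src) * df_dlam(lam)
```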
Then:

dU/dλ = dU/dp . dp/dλ = dU/dp . dp/df . df/dλ

The term dp/df = p_dst - p_src is fixed, as before. For the softcore/4D decoupling part, our distances are computed as:

U(r(f(λ))) = U(sqrt(k + f(λ)^2))

We can proceed similarly via:

dU/dλ = dU/dr . dr/dλ = dU/dr . dr/df . df/dλ
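A sketch of that second chain rule, where `u_pair` is a placeholder pairwise potential and `k` stands in for whatever fixed term sits under the square root (both are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def f(lam):
    # Placeholder schedule with f(0) = 0 and f(1) = 1.
    return lam ** 2

def r_eff(lam, k):
    # Effective distance with the 4D offset: r = sqrt(k + f(lam)^2)
    return jnp.sqrt(k + f(lam) ** 2)

def u_pair(r):
    # Placeholder pairwise potential of the effective distance.
    return 1.0 / r

def u_of_lam(lam, k):
    return u_pair(r_eff(lam, k))

# dU/dlam = dU/dr . dr/df . df/dlam, obtained in one shot with jax
du_dlam = jax.grad(u_of_lam, argnums=0)
```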
In the above, the various `f(λ)`, `df/dλ`, `dp/df`, and `dr/df` would be computed once in the constructor and re-used throughout. In particular, we'd have four per-particle expressions:

f_eps(λ), f_sig(λ), f_q(λ), f_w(λ)

in addition to their derivatives:

df_eps/dλ, df_sig/dλ, df_q/dλ, df_w/dλ
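A sketch of how these could be precomputed and cached in a constructor. This assumes evaluation on a fixed lambda grid and is a hypothetical class, not the existing Nonbonded implementation; the per-particle layout and the specific schedules are illustrative only:

```python
import jax
import jax.numpy as jnp

class InterpolationSchedule:
    """Precompute f(lam) and df/dlam for eps, sig, q, and w over a fixed lambda grid."""

    def __init__(self, lambda_schedule, f_eps, f_sig, f_q, f_w):
        self.lambda_schedule = jnp.asarray(lambda_schedule)
        # Evaluate each schedule and its derivative once; the results are flat
        # vectors (one entry per lambda value) that can be cached and re-used.
        self.f = {}
        self.df_dlam = {}
        for name, fn in [("eps", f_eps), ("sig", f_sig), ("q", f_q), ("w", f_w)]:
            self.f[name] = jax.vmap(fn)(self.lambda_schedule)
            self.df_dlam[name] = jax.vmap(jax.grad(fn))(self.lambda_schedule)

# Example usage with simple schedules satisfying f(0) = 0 and f(1) = 1:
schedule = InterpolationSchedule(
    jnp.linspace(0.0, 1.0, 11),
    f_eps=lambda lam: lam,
    f_sig=lambda lam: lam,
    f_q=lambda lam: jnp.clip(2.0 * lam, 0.0, 1.0),  # charges vary twice as fast
    f_w=lambda lam: lam ** 2,
)
```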