proteneer / timemachine

Differentiate all the things!

Allow lambda scheduling to be more flexible. #402

Open proteneer opened 3 years ago

proteneer commented 3 years ago

To keep my sanity, I'd like to make the nonbonded lambda scheduling more flexible than the rather rigid structure we have right now. I'm really getting tired of managing multi-stage setups and dealing with their derivatives. Ping me offline for more detail re: the jank involved in parameter interpolation and its derivatives.

In particular for the Nonbonded potential, I'd like to be able to rescale eps, sig, q, and w independently of each other.

Suppose we have some arbitrary continuous function f(λ): [0,1] -> R, subject to the boundary conditions f(0)=0 and f(1)=1. Generally, we'd like to implement f(λ) in Python for full expressiveness; its derivative can then be evaluated trivially using JAX. Note that f(λ) and its derivative only need to be computed once, are flat vectors, and can be cached.
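
As a rough illustration of one such f(λ), here is a minimal sketch using a smoothstep polynomial. The form is an arbitrary choice, and the names `f_smoothstep`, `df_smoothstep`, `check_endpoints` and `finite_difference` are hypothetical, not part of timemachine. The derivative is written by hand to keep the sketch dependency-free; with JAX, `jax.grad` applied to the same function should give the same values.

```python
def f_smoothstep(lam):
    """Hypothetical schedule f(lam) = 3 lam^2 - 2 lam^3, with f(0)=0 and f(1)=1."""
    return 3.0 * lam**2 - 2.0 * lam**3


def df_smoothstep(lam):
    """Hand-written derivative df/dlam = 6 lam - 6 lam^2."""
    return 6.0 * lam - 6.0 * lam**2


def check_endpoints(f, tol=1e-12):
    """Verify the endpoint conditions f(0)=0 and f(1)=1."""
    return abs(f(0.0)) < tol and abs(f(1.0) - 1.0) < tol


def finite_difference(f, lam, h=1e-6):
    """Central finite difference, used only to sanity-check the derivative."""
    return (f(lam + h) - f(lam - h)) / (2.0 * h)


# Evaluate once for a given lambda and reuse the cached values afterwards.
lam = 0.3
f_val, df_val = f_smoothstep(lam), df_smoothstep(lam)
```

Checking the endpoints up front is cheap and catches schedules that would silently change the end states.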

For parameter interpolation, we have:

p = (1-f(λ)) p_src + f(λ) p_dst

Then:

dU/dλ = dU/dp . dp/dλ = dU/dp . dp/df . df/dλ

The middle term dp/df = p_dst - p_src is fixed, as before. For the softcore/4D decoupling part, our distances are computed as:

U(r(f(λ))) = U(sqrt(k+f(λ)^2))

We can proceed similarly via:

dU/dλ = dU/dr . dr/dλ = dU/dr . dr/df . df/dλ
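
The two chain-rule expressions above can be checked numerically. The sketch below uses toy stand-ins for U (a quadratic in p and a 1/r term), a hypothetical schedule f(λ) = λ^2, and treats k as a λ-independent constant. None of these are the actual timemachine potentials.

```python
import math


def f(lam):
    """Hypothetical schedule with f(0)=0 and f(1)=1."""
    return lam**2


def df_dlam(lam):
    return 2.0 * lam


# --- Parameter interpolation: p = (1 - f) p_src + f p_dst ---
def u_params(lam, p_src, p_dst):
    p = (1.0 - f(lam)) * p_src + f(lam) * p_dst
    return 0.5 * p**2  # toy U(p), assumed for illustration


def du_dlam_params(lam, p_src, p_dst):
    p = (1.0 - f(lam)) * p_src + f(lam) * p_dst
    dU_dp = p                # derivative of the toy U(p)
    dp_df = p_dst - p_src    # fixed, independent of lambda
    return dU_dp * dp_df * df_dlam(lam)


# --- Softcore / 4D decoupling: r = sqrt(k + f^2) ---
def u_softcore(lam, k):
    r = math.sqrt(k + f(lam) ** 2)
    return 1.0 / r  # toy U(r), assumed for illustration


def du_dlam_softcore(lam, k):
    r = math.sqrt(k + f(lam) ** 2)
    dU_dr = -1.0 / r**2      # derivative of the toy U(r)
    dr_df = f(lam) / r       # d/df sqrt(k + f^2)
    return dU_dr * dr_df * df_dlam(lam)


def central_diff(g, lam, h=1e-6):
    """Central finite difference in lambda, for validation only."""
    return (g(lam + h) - g(lam - h)) / (2.0 * h)
```

Comparing each analytic product against `central_diff` is a quick way to catch a dropped factor of df/dλ.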

In the above, the various f(λ), df/dλ, and dr/df would be computed once in the constructor and reused throughout. In particular, we'd have four per-particle expressions:

f_eps(λ), f_sig(λ), f_q(λ), f_w(λ)

in addition to their derivatives:

df_eps/dλ, df_sig/dλ, df_q/dλ, df_w/dλ
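
One way the constructor-time caching could look is sketched below, assuming each particle gets its own (f, df/dλ) pair per channel. `ScheduleCache` and the schedule helpers are hypothetical names for illustration and do not reflect the actual timemachine API.

```python
from dataclasses import dataclass, field


def linear(lam):
    return lam


def d_linear(lam):
    return 1.0


def quadratic(lam):
    return lam**2


def d_quadratic(lam):
    return 2.0 * lam


LINEAR = (linear, d_linear)
QUADRATIC = (quadratic, d_quadratic)


@dataclass
class ScheduleCache:
    """Hypothetical cache of per-particle f(lam) and df/dlam for each channel."""

    lam: float
    schedules: dict  # channel name -> list of (f, df) pairs, one per particle
    values: dict = field(init=False)
    derivs: dict = field(init=False)

    def __post_init__(self):
        # Evaluate every schedule once; the results are flat per-particle lists.
        self.values = {
            ch: [f(self.lam) for f, _ in pairs] for ch, pairs in self.schedules.items()
        }
        self.derivs = {
            ch: [df(self.lam) for _, df in pairs] for ch, pairs in self.schedules.items()
        }


# Two particles, each with independently chosen schedules per channel.
cache = ScheduleCache(
    lam=0.5,
    schedules={
        "eps": [LINEAR, QUADRATIC],
        "sig": [LINEAR, LINEAR],
        "q": [QUADRATIC, QUADRATIC],
        "w": [QUADRATIC, LINEAR],
    },
)
```

After construction, `cache.values["eps"]` and `cache.derivs["eps"]` are plain per-particle lists that can be reused at every step with the same λ.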

proteneer commented 3 years ago

One hiccup is that for non-equilibrium simulations this will be ungodly painful, and it's one of those times I wish we had the custom force support of OpenMM.

maxentile commented 3 years ago

Big +1 on this!

For a given way to expose controllable dials within the potential energy function (e.g. distance offsets, parameter scales, etc.), there are still many ways to vary these dials as a function of lambda.

I like to think of a "protocol" as a vector-valued function of a single controllable scalar -- something like protocol(lam: float) -> control_params: array, to be used in conjunction with a potential energy function u_controllable(x, control_params: array) -> float... We can apply any specific choice of the protocol function to yield a new function u(x, lam) = u_controllable(x, protocol(lam)) suitable for free energy calculations.
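
A minimal sketch of this composition follows, with a toy two-dial energy function. The dials (`charge_scale`, `w_offset`), the specific protocol, and the helper `apply_protocol` are all assumptions made up for illustration.

```python
import math


def protocol(lam):
    """Hypothetical protocol: scalar lam -> [charge_scale, w_offset]."""
    charge_scale = 1.0 - lam  # turn charges off linearly
    w_offset = lam**2         # grow the 4D offset quadratically
    return [charge_scale, w_offset]


def u_controllable(x, control_params):
    """Toy energy with two controllable dials (not a real potential)."""
    charge_scale, w_offset = control_params
    r = math.sqrt(x * x + w_offset * w_offset)
    return charge_scale / r


def apply_protocol(u_ctrl, proto):
    """Build u(x, lam) = u_ctrl(x, proto(lam))."""

    def u(x, lam):
        return u_ctrl(x, proto(lam))

    return u


u = apply_protocol(u_controllable, protocol)
```

Keeping `protocol` separate from `u_controllable` means a different schedule can be swapped in without touching the energy function.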

In the case of a parametric family of optimizable protocols, we would allow this vector-valued function to depend also on some protocol_params -- something like protocol(lam: float, protocol_params: array) -> array.
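
For instance, a parametric family could be sketched as a set of power laws, one exponent per control parameter. The name `parametric_protocol` and the power-law form are hypothetical choices, made only because they keep the endpoints fixed for any positive exponent.

```python
from functools import partial


def parametric_protocol(lam, protocol_params):
    """Hypothetical family: control parameter i follows lam ** protocol_params[i].

    For positive exponents, every component is 0 at lam=0 and 1 at lam=1.
    """
    return [lam**exponent for exponent in protocol_params]


# Binding protocol_params recovers a plain protocol(lam) -> array.
fixed_protocol = partial(parametric_protocol, protocol_params=[1.0, 2.0])
```

An optimizer would then adjust `protocol_params` while everything downstream still sees an ordinary function of λ.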

A practical complication avoided by this "vector-valued function of lambda" picture is the difficulty of jointly tuning "lambda scheduling" and "lambda spacing." The assumption would be you're always going to drag the scalar lambda from 0 to 1 at a constant speed (lambda_schedule = np.linspace(0,1,n_steps)), but as needed you could make some of the control_params vary faster or slower as a function of lambda at different points.
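
To make that concrete, here is a small sketch in which λ advances uniformly while two hypothetical control parameters move at different rates. The list comprehension mirrors `np.linspace(0, 1, n_steps)`, and `two_speed_protocol` is an illustrative name.

```python
n_steps = 5
# Uniformly spaced lambda values, equivalent to np.linspace(0, 1, n_steps).
lambda_schedule = [i / (n_steps - 1) for i in range(n_steps)]


def two_speed_protocol(lam):
    """Hypothetical protocol whose components change at different rates."""
    fast_early = 1.0 - (1.0 - lam) ** 2  # changes quickly near lam = 0
    slow_early = lam**2                  # changes slowly near lam = 0
    return [fast_early, slow_early]


# Control parameters visited as lambda is dragged from 0 to 1 at constant speed.
trajectory = [two_speed_protocol(lam) for lam in lambda_schedule]
```

Both components start at 0 and end at 1, but at λ = 0.25 the first has already covered much more of its range than the second.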