jdrusso opened 2 years ago
Some more details on what this might look like, and some of the implementation challenges:
After performing optimization, the dimensionality-reduced features provided to the haMSM must be used as additional progress coordinates.
For SynD, this can be achieved by simply extending the pcoord definitions in the backmapping.
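A SynD model has a finite set of discrete states. Extending the progress coordinate therefore amounts to precomputing the reduced coordinates once per state and appending them to whatever the backmapping already returns. The sketch below is minimal and assumes the backmapping can be treated as a plain callable from a state index to a 1-D pcoord array. `extend_backmap`, `original_backmap`, and `reduced_coords` are hypothetical names, not SynD's API.

```python
import numpy as np


def extend_backmap(original_backmap, reduced_coords):
    """Return a new backmapping that appends dimensionality-reduced coordinates.

    original_backmap: hypothetical callable, state index -> 1-D pcoord array
    reduced_coords: array of shape (n_states, n_reduced_dims), one row per
        discrete state (e.g. the haMSM's reduced featurization of that state)
    """
    reduced_coords = np.asarray(reduced_coords, dtype=float)

    def extended_backmap(state):
        original = np.atleast_1d(np.asarray(original_backmap(state), dtype=float))
        # Original pcoord dimensions first, then the new reduced dimensions
        return np.concatenate((original, reduced_coords[state]))

    return extended_backmap


# Toy example: 3 states, a 1-D original pcoord, 2 reduced dimensions
original_pcoords = np.array([[0.5], [1.5], [2.5]])
reduced = np.array([[0.1, -0.2], [0.3, 0.0], [-0.4, 0.9]])
backmap = extend_backmap(lambda s: original_pcoords[s], reduced)
print(backmap(1))
```

Because the state space is finite, the reduced coordinates only need to be computed once per model, not once per segment.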
However, in general this is a tricky problem, because the approach has to be flexible enough to work with different propagators, etc. In rough terms, it might look something like:
import numpy as np
import msm_we

def get_segment_coordinates(segment):
    ...

def get_extended_pcoord(segment, hamsm: msm_we.modelWE):
    # Get the coordinates -- we should probably get this directly from the augmented H5,
    # meaning the augmentation must run before the pcoord calculation
    seg_coords = get_segment_coordinates(segment)

    # This just wraps however WESTPA would normally get the pcoord
    original_pcoord = segment.pcoord

    # Get the new dimensions of the pcoord, from the dimensionality-reduced MSM featurization of this segment
    new_extended_pcoord = hamsm.coordinates.transform(hamsm.processCoordinates(seg_coords))

    # np.concatenate takes a sequence of arrays; join along the last (pcoord-dimension) axis
    full_extended_pcoord = np.concatenate((original_pcoord, new_extended_pcoord), axis=-1)
    return full_extended_pcoord

# Overload pcoord calculation with our wrapper
def wrapped(original_propagation, hamsm):
    def wrapper(*args, **kwargs):
        segments = original_propagation(*args, **kwargs)
        for segment in segments:
            segment.pcoord = get_extended_pcoord(segment, hamsm)
        return segments
    return wrapper

wm_ops.propagate = wrapped(wm_ops.propagate, hamsm)
# Note that get_pcoord actually returns a state, not a list of segments, so the naming is a little misleading
wm_ops.get_pcoord = wrapped(wm_ops.get_pcoord, hamsm)
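The wrapping approach above depends on WESTPA and msm_we objects. To check the concatenation and wrapping logic in isolation, here is a self-contained sketch that uses hypothetical stand-ins. `Segment` mimics a segment with a `(n_timepoints, n_dims)` pcoord array. `StubModel` fakes `processCoordinates` and `coordinates.transform` with a fixed linear projection. `fake_propagate` plays the role of `wm_ops.propagate`. None of these are real WESTPA or msm_we classes.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import numpy as np


@dataclass
class Segment:
    # Hypothetical stand-in for a WESTPA segment
    pcoord: np.ndarray  # shape (n_timepoints, n_pcoord_dims)
    coords: np.ndarray  # shape (n_timepoints, n_features)


class StubModel:
    """Hypothetical stand-in for a trained haMSM (not msm_we.modelWE)."""

    def __init__(self, projection):
        self._projection = np.asarray(projection, dtype=float)
        # Mimic a fitted dimensionality reduction exposing .transform()
        self.coordinates = SimpleNamespace(transform=lambda x: x @ self._projection)

    def processCoordinates(self, coords):
        # Identity featurization, for this sketch only
        return coords


def get_extended_pcoord(segment, hamsm):
    features = hamsm.processCoordinates(segment.coords)
    reduced = hamsm.coordinates.transform(features)
    # Append the reduced dimensions as extra pcoord columns, one row per timepoint
    return np.concatenate((segment.pcoord, reduced), axis=-1)


def wrapped(original_propagation, hamsm):
    def wrapper(*args, **kwargs):
        segments = original_propagation(*args, **kwargs)
        for segment in segments:
            segment.pcoord = get_extended_pcoord(segment, hamsm)
        return segments
    return wrapper


def fake_propagate(segments):
    # Pretend propagation: returns the segments unchanged
    return segments


model = StubModel(projection=np.eye(6)[:, :2])  # keep the first two features
segs = [Segment(pcoord=np.zeros((3, 1)), coords=np.arange(18.0).reshape(3, 6))]
propagate = wrapped(fake_propagate, model)
out = propagate(segs)
print(out[0].pcoord.shape)  # (3, 3)
```

The original 1-D pcoord keeps its column, and the two reduced dimensions are appended after it. As a result, any existing binning on the first dimension still sees the same values.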
Maybe the optimization plugin can overload the `wm_ops` methods with this wrapper after it runs?
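The optimization plugin might run more than once. In that case, reassigning `wm_ops.propagate = wrapped(wm_ops.propagate, ...)` on every run would stack one wrapper per run and extend the pcoord repeatedly. One way to avoid this is to save the unwrapped functions the first time and always re-wrap those. The sketch below assumes the plugin can keep such references between runs. `ops` is a `SimpleNamespace` standing in for `wm_ops`, and `install_extended_pcoord` and `extend` are hypothetical helpers.

```python
from types import SimpleNamespace

# Unwrapped functions, keyed by (namespace id, attribute name)
_ORIGINALS = {}


def install_extended_pcoord(ops, make_wrapper, names=("propagate", "get_pcoord")):
    """Patch ops.<name> with make_wrapper(original) for each name.

    The unwrapped function is saved the first time. Calling this again after
    a new optimization therefore replaces the previous wrapper instead of
    stacking another one on top of it.
    """
    for name in names:
        key = (id(ops), name)
        if key not in _ORIGINALS:
            _ORIGINALS[key] = getattr(ops, name)
        setattr(ops, name, make_wrapper(_ORIGINALS[key]))


def extend(original, tag):
    # Toy wrapper: append a marker to whatever the original returns
    def wrapper(*args, **kwargs):
        return original(*args, **kwargs) + [tag]
    return wrapper


ops = SimpleNamespace(propagate=lambda: ["segment"], get_pcoord=lambda: ["state"])
install_extended_pcoord(ops, lambda f: extend(f, "model_v1"))
install_extended_pcoord(ops, lambda f: extend(f, "model_v2"))
print(ops.propagate())  # ['segment', 'model_v2']
```

Only the wrapper from the most recent optimization is active, so a re-trained haMSM replaces the old reduced coordinates instead of adding more.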
The catches are:
- The optimization plugin must extend the progress coordinate using the dimensionality reduction produced by the haMSM. Currently, this is only supported for SynD, because it's easy to recompute a value for every pcoord. However, a more generic implementation could look something like wrapping `get_pcoord` with a call to `model.processCoordinates`, and returning the combination of the original pcoord and the new dimensionality-reduced coordinates. Recall, though, that `get_pcoord` isn't necessarily called during propagation.
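Because `get_pcoord` may not run during propagation, wrapping it alone is not enough. Segment pcoords that come out of `propagate` also need extending. The sketch below subclasses a propagator and overrides both paths. It assumes the propagator follows the pattern where `get_pcoord(state)` sets `state.pcoord` in place and `propagate(segments)` fills in each segment's pcoord. `ToyPropagator`, `ExtendedPcoordPropagator`, and `reduce_fn` are hypothetical; `reduce_fn` stands in for something like `model.coordinates.transform(model.processCoordinates(...))`.

```python
from types import SimpleNamespace

import numpy as np


class ToyPropagator:
    """Hypothetical stand-in for a WESTPA propagator."""

    def get_pcoord(self, state):
        # Sets the pcoord of a basis/initial state in place
        state.pcoord = np.array([1.0])

    def propagate(self, segments):
        for segment in segments:
            segment.pcoord = np.array([[1.0], [2.0]])


class ExtendedPcoordPropagator(ToyPropagator):
    """Append reduced coordinates after both pcoord code paths.

    reduce_fn is a hypothetical callable mapping a state or segment to an
    array of reduced coordinates whose leading shape matches its pcoord.
    """

    def __init__(self, reduce_fn):
        super().__init__()
        self.reduce_fn = reduce_fn

    def get_pcoord(self, state):
        super().get_pcoord(state)
        state.pcoord = np.concatenate((state.pcoord, self.reduce_fn(state)), axis=-1)

    def propagate(self, segments):
        super().propagate(segments)
        # get_pcoord is not necessarily called here, so extend the segments too
        for segment in segments:
            segment.pcoord = np.concatenate(
                (segment.pcoord, self.reduce_fn(segment)), axis=-1
            )


# Toy reduce_fn: two zero-valued reduced dimensions per pcoord row
prop = ExtendedPcoordPropagator(lambda obj: np.zeros(obj.pcoord.shape[:-1] + (2,)))
state = SimpleNamespace(pcoord=None)
seg = SimpleNamespace(pcoord=None)
prop.get_pcoord(state)
prop.propagate([seg])
print(state.pcoord.shape, seg.pcoord.shape)  # (3,) (2, 3)
```

Compared with monkey-patching `wm_ops`, a subclass keeps the extension next to the propagator it depends on. However, it has to be written once per propagator type, which is the generality problem raised above. A real propagator's constructor may also need its own arguments passed through `super().__init__`.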