jdrusso / msm_we

History-augmented Markov analysis of weighted ensemble trajectories.
https://msm-we.readthedocs.io
MIT License

Make optimization pcoord extension generic, not SynD specific #28

Open jdrusso opened 2 years ago

jdrusso commented 2 years ago

The optimization plugin must extend the progress coordinate using the dimensionality reduction produced by the haMSM.

Currently, this is only supported for SynD, because it's easy to recompute a value for every pcoord.

However, a more generic implementation could look something like wrapping get_pcoord with a call to model.processCoordinates, and returning the combination of the original pcoord and the new dimensionality-reduced coordinates.
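The wrapping idea can be sketched in a self-contained way. This is a minimal illustration, not msm_we's API. `ToyModel`, `ToyProjection`, and `make_extended_get_pcoord` are hypothetical. The flattening featurization and the linear projection are placeholders for whatever the haMSM actually uses. Only the method names `processCoordinates` and `coordinates.transform` mirror those used in this issue. The sketch also assumes the pcoord function returns one row (or one value) per frame.

```python
import numpy as np


class ToyProjection:
    """Hypothetical stand-in for the haMSM's fitted dimensionality reduction."""

    def __init__(self, components):
        # components: (n_features, n_reduced) projection matrix
        self.components = np.asarray(components, dtype=float)

    def transform(self, features):
        # Linear projection of each featurized frame onto the reduced space
        return np.asarray(features, dtype=float) @ self.components


class ToyModel:
    """Hypothetical stand-in for an optimized msm_we.modelWE."""

    def __init__(self, components):
        self.coordinates = ToyProjection(components)

    def processCoordinates(self, coords):
        # Hypothetical featurization: flatten each frame's (n_atoms, 3) coordinates
        coords = np.asarray(coords, dtype=float)
        return coords.reshape(coords.shape[0], -1)


def make_extended_get_pcoord(get_pcoord, model):
    """Wrap a pcoord function so it also returns the model's reduced coordinates."""

    def extended_get_pcoord(coords):
        original = np.asarray(get_pcoord(coords), dtype=float)
        if original.ndim == 1:
            # One pcoord value per frame -> column vector
            original = original[:, np.newaxis]
        reduced = model.coordinates.transform(model.processCoordinates(coords))
        if original.shape[0] != reduced.shape[0]:
            raise ValueError("pcoord and reduced coordinates disagree on number of frames")
        # Keep the original dimensions first, then append the reduced ones
        return np.concatenate((original, reduced), axis=-1)

    return extended_get_pcoord


if __name__ == "__main__":
    # Two frames of 2 atoms in 3D, i.e. 6 features per frame
    frames = np.arange(12, dtype=float).reshape(2, 2, 3)
    # Hypothetical 1D reduction: the sum of all features
    model = ToyModel(np.ones((6, 1)))
    # Original pcoord: x-coordinate of the first atom
    extended = make_extended_get_pcoord(lambda c: c[:, 0, 0], model)
    print(extended(frames))  # rows [0, 15] and [6, 51]
```

Keeping the original dimensions first means existing binning on the original pcoord dimensions can keep indexing them unchanged.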

Recall, though, that get_pcoord isn't necessarily called during propagation, so wrapping it alone may not be enough.

jdrusso commented 1 year ago

Some more details on what this might look like, and some of the implementation challenges:

After performing optimization, the dimensionality-reduced features provided to the haMSM must be used as additional progress coordinates.

For SynD, this can be achieved by simply extending the pcoord definitions in the backmapping.
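Because SynD's dynamics are defined over discrete states, the extended pcoord can be computed once per state and stored. The following is a minimal sketch, not SynD's API. It assumes the backmapping can be represented as a plain dict from state index to pcoord, plus a second dict from state index to that state's featurized coordinates. `extend_backmapping` and `reduce_fn` are hypothetical names.

```python
import numpy as np


def extend_backmapping(pcoord_map, feature_map, reduce_fn):
    """Return a new state -> pcoord mapping with reduced coordinates appended.

    pcoord_map:  dict of state index -> original pcoord (scalar or 1D array)
    feature_map: dict of state index -> featurized coordinates (1D array)
    reduce_fn:   callable mapping an (n_states, n_features) array to an
                 (n_states, n_reduced) array; a hypothetical stand-in for
                 the haMSM's dimensionality reduction
    """
    states = sorted(pcoord_map)
    # Reduce every state in one batch, one row per state
    features = np.stack([np.asarray(feature_map[s], dtype=float) for s in states])
    reduced = np.asarray(reduce_fn(features), dtype=float)

    extended = {}
    for row, state in enumerate(states):
        original = np.atleast_1d(np.asarray(pcoord_map[state], dtype=float))
        extended[state] = np.concatenate((original, reduced[row]))
    return extended


if __name__ == "__main__":
    pcoord_map = {0: 1.0, 1: 2.0, 2: 3.5}
    feature_map = {0: [0.0, 1.0], 1: [1.0, 1.0], 2: [2.0, 0.0]}
    # Hypothetical reduction: keep only the first feature
    extended = extend_backmapping(pcoord_map, feature_map, lambda f: f[:, :1])
    print(extended[2])  # values 3.5 and 2.0
```

Precomputing per state means the propagator never has to call the haMSM during the run. The trade-off is that the table must be rebuilt whenever the model is re-optimized.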

In general, however, this is a tricky problem, because the implementation has to work across different propagators (among other things). In rough terms, it might look something like this:

```python
import functools

import numpy as np
from westpa.core import wm_ops  # WESTPA 2.x module path


def get_segment_coordinates(segment):
    ...


def get_extended_pcoord(segment, hamsm):
    # hamsm: the optimized msm_we.modelWE

    # Get the coordinates -- we should probably get this directly from the augmented H5,
    #   meaning the augmentation must run before the pcoord calculation
    seg_coords = get_segment_coordinates(segment)

    # This just wraps however WESTPA would normally get the pcoord
    original_pcoord = segment.pcoord

    # Get the new dimensions of the pcoord, from the dimensionality-reduced MSM featurization of this segment
    new_extended_pcoord = hamsm.coordinates.transform(hamsm.processCoordinates(seg_coords))

    # Append the new dimensions after the original ones, timepoint by timepoint
    full_extended_pcoord = np.concatenate((original_pcoord, new_extended_pcoord), axis=-1)

    return full_extended_pcoord


# Overload pcoord calculation with our wrapper
def wrapped(original_method, hamsm):

    @functools.wraps(original_method)
    def wrapper(*args, **kwargs):
        result = original_method(*args, **kwargs)

        # propagate returns a list of segments, but get_pcoord returns a single state
        segments = result if isinstance(result, (list, tuple)) else [result]
        for segment in segments:
            segment.pcoord = get_extended_pcoord(segment, hamsm)

        return result

    return wrapper


# hamsm is the trained msm_we.modelWE produced by the optimization step
wm_ops.propagate = wrapped(wm_ops.propagate, hamsm)

# Note that get_pcoord actually returns a state, not a list of segments, so the naming is a little misleading
wm_ops.get_pcoord = wrapped(wm_ops.get_pcoord, hamsm)
```

Maybe the optimization plugin can overload the wm_ops methods with this wrapper after it runs?

The catches are: