choderalab / chiron

Differentiable Markov Chain Monte Carlo
https://github.com/choderalab/chiron/wiki
MIT License

MC move refactoring #20

Closed chrisiacovella closed 5 months ago

chrisiacovella commented 7 months ago

We need to figure out a good balance between OOPing and code readability/extensibility.

Currently, we have a pretty significant hierarchy of inherited classes in terms of the MCMC moves.

The basic MetropolisDisplacementMove inherits from MetropolizedMove which inherits from MCMove, which inherits from StateUpdateMove. While some hierarchy is definitely going to be necessary, this might be a bit excessive from a code readability standpoint (even if it will ultimately reduce the amount of code needed for each move).

The MCMove class seems to sketch out a reasonably clear structure as a starting point (though I'll note it isn't necessarily followed in the moves that build upon the MetropolizedMove class).

In MCMove there are a few key functions already in there:

I'll note that the MCMCSampler class expects moves to have a run function, so adding that to the MCMove class would likely be a good plan (run doesn't necessarily change much between different routines, as it essentially calls a function that performs the move n times).

In terms of a generally applicable and easy-to-extend structure, let's start from the run function for an MC displacement move.

This approach would only require a user to define compute_acceptance_probability and _generate_new_state. It would make the acceptance criteria much clearer (it should not require many logic statements with regard to checking the thermodynamic state variables, unlike the current apply function in the MetropolizedMove class).

New MCMove key functions:

For functions where we want to allow dynamic updates to parameters, such as the maximum volume scaling or maximum displacement, to achieve a target acceptance probability, this can likely happen within the _generate_new_state function (before the state is actually generated).

I think the best course of action will be to do this in a separate PR after #14 and #19 are merged.
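One way such a dynamic update could look, as a minimal sketch: the helper name `tune_max_displacement` and its parameters (`target_acceptance`, `scale`, `lower`, `upper`) are hypothetical and not part of chiron. It rescales the maximum displacement from the running acceptance ratio and would be called at the top of `_generate_new_state`, before the new state is drawn.

```python
def tune_max_displacement(max_displacement, n_accepted, n_proposed,
                          target_acceptance=0.5, scale=1.1,
                          lower=1e-4, upper=10.0):
    """Return a rescaled maximum displacement (hypothetical helper)."""
    if n_proposed == 0:
        return max_displacement  # no statistics yet, leave unchanged
    acceptance = n_accepted / n_proposed
    if acceptance > target_acceptance:
        # accepting too often: proposals are too timid, take bigger steps
        new_value = max_displacement * scale
    elif acceptance < target_acceptance:
        # rejecting too often: take smaller steps
        new_value = max_displacement / scale
    else:
        new_value = max_displacement
    # clamp so the step size can neither collapse nor grow without bound
    return min(max(new_value, lower), upper)
```

Adapting the step size while collecting samples generally breaks detailed balance, so an update like this is usually restricted to an equilibration phase.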

chrisiacovella commented 7 months ago

The current plan is to do the refactoring in #21 (as it includes the new reporter changes) and skip merging #14.

The outcome of the discussion of the API for MC moves mostly follows what was outlined above. @wiederm will post the pseudo code.

wiederm commented 7 months ago

Effectively, we want to implement this for each MCMCMove:

(proposed_sampler_state, log_proposal_ratio,  proposed_thermodynamic_state) = self._propose(current_sampler_state, current_thermodynamic_state)  # log proposal ratio + proposal sampler state
current_reduced_pot = current_thermodynamic_state.get_reduced_potential(current_sampler_state)
proposed_reduced_pot = proposed_thermodynamic_state.get_reduced_potential(proposed_sampler_state)
decision = self._accept_or_reject(current_reduced_pot, proposed_reduced_pot, log_proposal_ratio, method="metropolis")  # including the log acceptance ratio
if decision:
     self._replace_states(proposed_sampler_state, proposed_thermodynamic_state)
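For reference, the accept/reject step above could be sketched as a standalone function. This is an illustration under stated assumptions, not the chiron implementation: `accept_or_reject` and its `rng` argument are hypothetical, and `log_proposal_ratio` is assumed to mean log[g(old | new) / g(new | old)], so it is zero for a symmetric proposal.

```python
import math
import random


def accept_or_reject(current_reduced_pot, proposed_reduced_pot,
                     log_proposal_ratio, rng=random):
    """Metropolis-Hastings test on reduced (unitless) potentials."""
    # log acceptance ratio: -(u_proposed - u_current) + log proposal ratio
    # (sign convention of log_proposal_ratio is an assumption, see above)
    log_acceptance = (-(proposed_reduced_pot - current_reduced_pot)
                      + log_proposal_ratio)
    if log_acceptance >= 0.0:
        return True  # downhill or neutral proposals are always accepted
    # compare in log space; 1 - random() lies in (0, 1], so log is defined
    return math.log(1.0 - rng.random()) < log_acceptance
```

Working in log space avoids overflow in the exponential when the potential difference is large.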

chrisiacovella commented 7 months ago

Below I sketched out things a little more in detail. I think we agreed to name the main public function update and refer to each step as _step (but please correct me if I'm wrong). This will still need some work to streamline space operations (neighbor list updating and wrapping functions) but that can be fleshed out as code is written

A few questions:

from typing import List, Optional

class MCMove(MCMCMove):
    def __init__(self,
        atom_subset: Optional[List[int]] = None,
        nr_of_moves: int = 100,
        reporter: Optional[_SimulationReporter] = None,
        report_frequency: int = 1,
        method: str = "metropolis",
        ):
        """
        Initialize the move  stats

        Parameters
        ----------
        atom_subset : List[int], optional, default=None
            List of atom indices to apply the move to. If None, the move will be applied to all atoms.
        nr_of_moves : int, optional, default=100
            Number of times to apply the move.
        reporter : _SimulationReporter, optional, default=None
            Reporter to save the simulation data.
        report_frequency : int, optional, default=1
            Frequency with which to report the simulation data.
        method : str, optional, default="metropolis"
            Methodology to use for accepting or rejecting the proposed state.
        """
        self.atom_subset = atom_subset
        self.nr_of_moves = nr_of_moves
        self.reporter = reporter
        self.report_frequency = report_frequency
        self.method = method

    def update(self, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, nbr_list: Optional[PairsBase] = None):
        """
        Perform the defined move and update the state.
        """

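        # only the first step needs to compute the current reduced potential;
        # later steps reuse the value stashed by _step
        calculate_current_potential = True
        for i in range(self.nr_of_moves):
            # use the states passed in; self.sampler_state is never set on the move
            self._step(sampler_state, thermodynamic_state, calculate_current_potential)
            calculate_current_potential = False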

    def _step(self, current_sampler_state, current_thermodynamic_state, calculate_current_potential):

        # if this is the first time we are calling this,
        # we will need to recalculate the reduced potential for the current state
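        if calculate_current_potential:
            current_reduced_pot = current_thermodynamic_state.get_reduced_potential(current_sampler_state)
            # stash it so later steps can reuse it even if this proposal is rejected
            self._current_reduced_pot = current_reduced_pot
        else:
            current_reduced_pot = self._current_reduced_pot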

        # propose a new state and calculate the log proposal ratio
        # this will be specific to the type of move
        # in addition to the sampler_state, this will require/return the thermodynamic state
        # for systems that e.g., make changes to particle identity.
        proposed_sampler_state, log_proposal_ratio,  proposed_thermodynamic_state = (
            self._propose(current_sampler_state, current_thermodynamic_state))

        # calculate the reduced potential for the proposed state
        proposed_reduced_pot = proposed_thermodynamic_state.get_reduced_potential(proposed_sampler_state)

        # accept or reject the proposed state
        decision = self._accept_or_reject(current_reduced_pot, proposed_reduced_pot, log_proposal_ratio, method=self.method)

        if decision:
            # stash the reduced potential of the accepted state so we don't have to recalculate it the next iteration
            self._current_reduced_pot = proposed_reduced_pot

            # replace the current state with the proposed state
            # not sure this needs to be a separate function but for simplicity in outlining the code it is fine
            # or should this return the new sampler_state and thermodynamic_state?
            self._replace_states(current_sampler_state, proposed_sampler_state, current_thermodynamic_state, proposed_thermodynamic_state)

        # a function that will update the statistics for the move
        self._update_statistics(decision)

    def _propose(self, current_sampler_state, current_thermodynamic_state):
        """
        Propose a new state and calculate the log proposal ratio.

        This will need to be defined for each move

        Parameters
        ----------
        current_sampler_state : SamplerState, required
            Current sampler state.
        current_thermodynamic_state : ThermodynamicState, required
            Current thermodynamic state.

        Returns
        -------
        proposed_sampler_state : SamplerState
            Proposed sampler state.
        log_proposal_ratio : float
            Log proposal ratio.
        proposed_thermodynamic_state : ThermodynamicState
            Proposed thermodynamic state.

        """
        # each concrete move must override this
        raise NotImplementedError

    def _replace_states(self, current_sampler_state, proposed_sampler_state, current_thermodynamic_state, proposed_thermodynamic_state):
        """
        Replace the current state with the proposed state.
        """
        # define the code to copy the proposed state to the current state

    def _accept_or_reject(self, current_reduced_pot, proposed_reduced_pot, log_proposal_ratio, method="metropolis"):
        """
        Accept or reject the proposed state with a given methodology.
        """
        # define the acceptance probability

chrisiacovella commented 6 months ago

Since _propose will return the log_proposal_ratio, we do not need a separate call to get the proposed reduced potential; that should happen within _propose. Similarly, _propose should return the proposed_reduced_potential and accept the current_reduced_potential (again, keeping track of these is simply to avoid having to calculate them again). Also, as noted in #24, we should have _step return the sampler/thermodynamic states, rather than modifying them in place.

def _step(self, current_sampler_state, current_thermodynamic_state, calculate_current_potential):

    # if this is the first time we are calling this,
    # we will need to recalculate the reduced potential for the current state
    if calculate_current_potential:
        current_reduced_pot = current_thermodynamic_state.get_reduced_potential(current_sampler_state)
        # stash it so later steps can reuse it even if this proposal is rejected
        self._current_reduced_pot = current_reduced_pot
    else:
        current_reduced_pot = self._current_reduced_pot

    # propose a new state and calculate the log proposal ratio
    # this will be specific to the type of move
    # in addition to the sampler_state, this will require/return the thermodynamic state
    # for systems that e.g., make changes to particle identity.
    proposed_sampler_state, log_proposal_ratio, proposed_thermodynamic_state, proposed_reduced_pot = (
        self._propose(current_sampler_state, current_thermodynamic_state, current_reduced_pot))

    # accept or reject the proposed state
    decision = self._accept_or_reject(current_reduced_pot, proposed_reduced_pot, log_proposal_ratio, method=self.method)

    if decision:
        # stash the reduced potential of the accepted state so we don't have to recalculate it the next iteration
        self._current_reduced_pot = proposed_reduced_pot

        return proposed_sampler_state, proposed_thermodynamic_state

    # rejected: hand back the unchanged current states instead of None
    return current_sampler_state, current_thermodynamic_state

chrisiacovella commented 6 months ago

In refactoring, a few other things have come up. New moves will only require users to define three functions:

chrisiacovella commented 6 months ago

Should _step return a new sampler/thermo state or just update the ones passed? I think either is fine, but if we update the ones that are passed, we should make sure to make a copy of the initial states, in case we want to, say, reject the entire move (that is, if we run 500 displacement steps, reject the whole trajectory).

As we mentioned in #24 and #23, I think the idea will be to have functions accept and return sampler and thermodynamic states (and any other associated parameters we might want).
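A toy sketch of the accept-and-return style, to show why it makes rejecting a whole trajectory cheap. Everything here is hypothetical and stands in for the real classes: `ToyState` replaces the sampler state, `reduced_potential` is a 1D harmonic well, and the proposal is a symmetric uniform displacement.

```python
import math
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a sampler state: one coordinate."""
    x: float


def reduced_potential(state):
    return 0.5 * state.x ** 2  # toy harmonic well


def step(state, current_pot, rng, max_displacement=0.5):
    """Return (state, reduced_pot, accepted) without mutating the input."""
    proposed = ToyState(state.x + rng.uniform(-max_displacement, max_displacement))
    proposed_pot = reduced_potential(proposed)
    log_acceptance = current_pot - proposed_pot  # symmetric proposal
    if log_acceptance >= 0.0 or math.log(1.0 - rng.random()) < log_acceptance:
        return proposed, proposed_pot, True
    return state, current_pot, False


def update(state, n_steps, rng):
    """Run n_steps and return the final state plus the acceptance count."""
    pot = reduced_potential(state)
    n_accepted = 0
    for _ in range(n_steps):
        state, pot, accepted = step(state, pot, rng)
        n_accepted += accepted
    return state, n_accepted


initial = ToyState(x=0.0)
final, n_accepted = update(initial, 500, random.Random(1))
# `initial` was never modified, so discarding the whole trajectory is just:
state_after_rejection = initial
```

Because each step returns a new object, no explicit copy of the initial state is needed; the caller simply keeps its original reference.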

Two places where we don't follow this:

1. Currently, the MCMCSampler itself just updates in place. This should probably be changed, too. It will require a little refactoring, since here the constructor, as opposed to the run function, takes the sampler/thermo state.

2. The neighbor list is modified in place. I think we already discussed the idea of moving this into the sampler state to minimize the amount of stuff we need to pass around. This might be more efficient for some of the MC moves, because if we reject a move, it would basically revert automatically, even if we had to rebuild the neighbor list in the proposed state. I just need to think a bit about the easiest way to implement this, since the neighbor list depends on the sampler state itself (we can probably wrap this so we don't have to explicitly pass the sampler state, e.g. sampler_state.build_nbr_list() and sampler_state.check()).

chrisiacovella commented 6 months ago

I changed the code around such that the neighbor list is returned along with the sampler state and thermodynamic state. One place where this might be useful is in the case of a rejected MC move, since we can just step back to the prior neighbor list state, rather than having to do another check to make sure that the neighbor list modified by the proposed state still works.
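A minimal sketch of that idea, assuming a hypothetical `MoveState` container and a brute-force 1D pair search with no periodic boundaries (neither is chiron's actual neighbor list API): the pairs are built for the proposed positions, and a rejection hands back the prior positions together with their prior pairs.

```python
from itertools import combinations
from typing import NamedTuple, Tuple


class MoveState(NamedTuple):
    """Hypothetical bundle: positions travel with their neighbor pairs."""
    positions: Tuple[float, ...]
    nbr_pairs: Tuple[Tuple[int, int], ...]


def build_pairs(positions, cutoff):
    """Brute-force pair search in 1D (toy replacement for a neighbor list)."""
    return tuple((i, j) for i, j in combinations(range(len(positions)), 2)
                 if abs(positions[i] - positions[j]) <= cutoff)


def resolve_move(current, proposed_positions, cutoff, accepted):
    """Return the proposed bundle if accepted, else the untouched current one."""
    proposed = MoveState(proposed_positions,
                         build_pairs(proposed_positions, cutoff))
    return proposed if accepted else current


start = (0.0, 1.0, 5.0)
current = MoveState(start, build_pairs(start, 1.5))
rejected = resolve_move(current, (0.0, 4.0, 5.0), 1.5, accepted=False)
accepted = resolve_move(current, (0.0, 4.0, 5.0), 1.5, accepted=True)
```

On rejection no re-check of the neighbor list is needed, because the prior pairs were never overwritten.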

I left the MCMC Sampler in place, as we will need to discuss that in the context of the multistate sampling.

The PR should be ready to go.