California-Planet-Search / radvel

General Toolkit for Modeling Radial Velocity Data
http://radvel.readthedocs.io
MIT License

Draw models from chain #324

Closed: paulmrobertson closed this issue 4 years ago

paulmrobertson commented 4 years ago

This utility function, draw_models_from_chain, appears to be outdated. A few issues:

  1. chain.ix should be replaced with chain.iloc, I think.

  2. The model has no attribute vary_parameters.
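To illustrate item 1: DataFrame.ix (deprecated in pandas 0.20 and removed in pandas 1.0) mixed label-based and position-based lookups. Its replacements split these apart: .loc selects by index label and .iloc by integer position. The sketch below uses a made-up chain; the column name k1 and the index values 100 to 103 are illustrative assumptions, chosen so that labels and positions differ (as they can after a chain has been sliced). Passing labels to .iloc then fails, while drawing positions with np.random.choice(len(chain), ...) works for any index.

```python
import numpy as np
import pandas as pd

# Toy chain whose index labels do not coincide with row positions.
chain = pd.DataFrame({"k1": [10.1, 9.8, 10.3, 10.0]}, index=[100, 101, 102, 103])

# .loc selects by index label, .iloc by integer position.
by_label = chain.loc[[100, 102], "k1"]
by_position = chain.iloc[[0, 2]]["k1"]

# Handing labels to .iloc treats them as positions, which are out of bounds here.
try:
    chain.iloc[[100, 102]]
except IndexError:
    print("labels passed to .iloc are out-of-bounds positions")

# Drawing positions (with replacement) is what .iloc expects.
rows = np.random.RandomState(0).choice(len(chain), 3)
draws = chain.iloc[rows]
```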

bjfultn commented 4 years ago

Thanks for the report! Do you have a working piece of code? If so, I can work that in if you aren't comfortable doing a PR.

paulmrobertson commented 4 years ago

OK, I think I've got it working. Here's the new version of draw_models_from_chain:

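import numpy as np


def draw_models_from_chain(mod, chain, t, nsamples=50):
    """Draw Models from Chain

    Given an MCMC chain of parameters, draw representative parameters
    and synthesize models.

    Args:
        mod (radvel.RVModel): RV model
        chain (DataFrame): pandas DataFrame with different values from MCMC
            chain
        t (array): time range over which to synthesize models
        nsamples (int): number of draws

    Returns:
        array: 2D array with the different models as different rows
    """

    # Seeded local generator: reproducible draws without resetting
    # NumPy's global random state.
    rng = np.random.RandomState(0)

    # .iloc expects integer positions, so draw positions (with replacement)
    # rather than index labels; this works even if the index is not 0..n-1.
    rows = rng.choice(len(chain), nsamples)
    chain_samples = chain.iloc[rows]

    vary_params = mod.list_vary_params()
    models = []
    for _, row in chain_samples.iterrows():
        # array_to_params writes this draw's values into mod.params in place,
        # so calling the model afterwards evaluates it at those values.
        mod.array_to_params(np.array(row[vary_params]))
        models.append(mod(t))
    return np.vstack(models)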
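The function returns one model per row and one time per column, so a common way to use the result is to take percentiles along axis 0 to get a spread of models at each time. The sketch below does not call radvel: it builds a hypothetical stand-in for the returned array from 50 sinusoidal curves with a scattered semi-amplitude (the period, amplitude, and scatter are made-up values).

```python
import numpy as np

# Hypothetical stand-in for the output of draw_models_from_chain:
# 50 sinusoidal RV curves whose semi-amplitude K scatters around 10 m/s.
rng = np.random.RandomState(0)
t = np.linspace(0.0, 20.0, 200)   # times in days
period = 5.0                      # assumed period in days
k_draws = rng.normal(10.0, 0.5, size=50)
models = np.vstack([k * np.sin(2 * np.pi * t / period) for k in k_draws])

# Rows are draws and columns are times, so percentiles along axis=0
# summarise the spread of the drawn models at every time.
lo, med, hi = np.percentile(models, [16, 50, 84], axis=0)
```

The lo and hi curves can then be shaded around med when plotting the fit.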
paulmrobertson commented 4 years ago

For this to work, though, you also have to add a few utility methods to the GeneralRVModel class:

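    def array_to_params(self, param_values):
        """Write an array of values into the varying parameters.

        Values are matched, in order, to the keys from list_vary_params().
        self.params is updated in place and also returned.
        """
        new_params = self.params

        for key, value in zip(self.list_vary_params(), param_values):
            # Set only the value so each Parameter keeps its existing
            # settings (such as vary) instead of being replaced.
            new_params[key].value = value

        return new_params

    def list_vary_params(self):
        """Return the names of the parameters that are allowed to vary."""
        keys = self.list_params()

        return [key for key in keys if self.params[key].vary]

    def list_params(self):
        """Return all parameter names, caching their order on first call."""
        try:
            keys = self.params_order
        except AttributeError:
            keys = list(self.params.keys())
            self.params_order = keys
        return keys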
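A quick way to check how these helpers interact without installing radvel is a toy model that carries the same three methods. ToyParameter and ToyModel below are hypothetical stand-ins (not radvel classes), and the parameter names and values are made up. Here array_to_params assigns each parameter's value attribute rather than building a new Parameter, so a fixed parameter keeps vary=False and the varying ones keep their flags.

```python
class ToyParameter:
    """Hypothetical minimal stand-in for radvel's Parameter."""
    def __init__(self, value, vary=True):
        self.value = value
        self.vary = vary


class ToyModel:
    """Hypothetical model carrying the three helper methods."""
    def __init__(self, params):
        self.params = params

    def list_params(self):
        # Cache the key order on first use so later calls agree.
        try:
            keys = self.params_order
        except AttributeError:
            keys = list(self.params.keys())
            self.params_order = keys
        return keys

    def list_vary_params(self):
        return [key for key in self.list_params() if self.params[key].vary]

    def array_to_params(self, param_values):
        # Match values to varying keys in order; update in place.
        for key, value in zip(self.list_vary_params(), param_values):
            self.params[key].value = value
        return self.params


mod = ToyModel({
    "per1": ToyParameter(5.0, vary=False),  # fixed, so skipped below
    "k1": ToyParameter(10.0),
    "gamma": ToyParameter(0.0),
})
print(mod.list_vary_params())  # ['k1', 'gamma']
mod.array_to_params([12.5, -3.0])
print(mod.params["k1"].value, mod.params["gamma"].value, mod.params["per1"].vary)
# 12.5 -3.0 False
```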
bjfultn commented 4 years ago

Incorporated in PR #388 and will be released in v1.4.2.