QInfer / python-qinfer

Library for Bayesian inference via sequential Monte Carlo for quantum parameter estimation.
BSD 3-Clause "New" or "Revised" License
92 stars 31 forks source link

simulate_experiment with domain error when n_outcomes is not constant #149

Open PengDauan opened 5 years ago

PengDauan commented 5 years ago

I was following the examples file:introduction_to_bayes_smc, I find the following issue when I was trying to call the simulate_experiment methods 20181119020922

PengDauan commented 5 years ago

class FiniteOutcomeModel(Model):

def domain(self, expparams):
    if self.is_n_outcomes_constant:
        return self._domain if expparams is None else [self._domain for ep in expparams]
    else:
        return [IntegerDomain(min=0,max=n_o-1) for n_o in self.n_outcomes(expparams)]

def simulate_experiment(self, modelparams, expparams, repeat=1):

    super(FiniteOutcomeModel, self).simulate_experiment(modelparams, expparams, repeat)

    if self.is_n_outcomes_constant:
        all_outcomes = self.domain(None).values
        probabilities = self.likelihood(all_outcomes, modelparams, expparams)
        cdf = np.cumsum(probabilities, axis=0)
        randnum = np.random.random((repeat, 1, modelparams.shape[0], expparams.shape[0]))
        outcome_idxs = all_outcomes[np.argmax(cdf > randnum, axis=1)]
        outcomes = all_outcomes[outcome_idxs]
    else:
        assert(self.are_expparam_dtypes_consistent(expparams))
        dtype = self.domain(expparams[0, np.newaxis])[0].dtype
        outcomes = np.empty((repeat, modelparams.shape[0], expparams.shape[0]), dtype=dtype)
        for idx_experiment, single_expparams in enumerate(expparams[:, np.newaxis]):
            all_outcomes = self.domain(single_expparams).values
            probabilities = self.likelihood(all_outcomes, modelparams, single_expparams)
            cdf = np.cumsum(probabilities, axis=0)[..., 0]
            randnum = np.random.random((repeat, 1, modelparams.shape[0]))
            outcomes[:, :, idx_experiment] = all_outcomes[np.argmax(cdf > randnum, axis=1)]

    return outcomes[0, 0, 0] if repeat == 1 and expparams.shape[0] == 1 and modelparams.shape[0] == 1 else outcomes
PengDauan commented 5 years ago
def simulate_experiment(self, modelparams, expparams, repeat=1):
    # Call the superclass simulate_experiment, not recording the result.
    # This is used to count simulation calls.
    super(FiniteOutcomeModel, self).simulate_experiment(modelparams, expparams, repeat)

    if self.is_n_outcomes_constant:
        # In this case, all expparams have the same domain
        all_outcomes = self.domain(None).values
        probabilities = self.likelihood(all_outcomes, modelparams, expparams)
        cdf = np.cumsum(probabilities, axis=0)
        randnum = np.random.random((repeat, 1, modelparams.shape[0], expparams.shape[0]))
        outcome_idxs = all_outcomes[np.argmax(cdf > randnum, axis=1)]
        outcomes = all_outcomes[outcome_idxs]
    else:
        # Loop over each experiment, sadly.
        # Assume all domains have the same dtype
        assert(self.are_expparam_dtypes_consistent(expparams))
        dtype = self.domain(expparams[0, np.newaxis])[0].dtype
        outcomes = np.empty((repeat, modelparams.shape[0], expparams.shape[0]), dtype=dtype)
        for idx_experiment, single_expparams in enumerate(expparams[:, np.newaxis]):
            all_outcomes = self.domain(single_expparams).values
            probabilities = self.likelihood(all_outcomes, modelparams, single_expparams)
            cdf = np.cumsum(probabilities, axis=0)[..., 0]
            randnum = np.random.random((repeat, 1, modelparams.shape[0]))
            outcomes[:, :, idx_experiment] = all_outcomes[np.argmax(cdf > randnum, axis=1)]

    return outcomes[0, 0, 0] if repeat == 1 and expparams.shape[0] == 1 and modelparams.shape[0] == 1 else outcomes

change the code in abstract.FiniteOutcomeModel.simulate_experiment all_outcomes= self.domain(single_expparams).values to all_outcomes= self.domain(single_expparams)[idx_experiment].values will solve the problem in this case, but I'm not sure this is the general case