g-walley / cegpy

Cegpy (/segpaɪ/) is a Python package for working with Chain Event Graphs. It supports learning the graphical structure of a Chain Event Graph from data, encoding parametric and structural priors, estimating its parameters, and performing inference.
MIT License

Issue when defining own prior #141

Open JackStorrorCarter opened 7 months ago

JackStorrorCarter commented 7 months ago

I encountered a problem when trying to define my own prior distribution for the transition probabilities. I created a dictionary 'prior' of the same form as the default prior. However, when I use this new prior I get the error 'Number of sub-lists in the list of priors must agree with the number of situations.' I think the length of the prior should match the number of edges rather than the number of situations. If so, then in the check_prior function, the line

if len(node_priors_list) != len(self.edge_countset):

should be changed to

if len(node_priors_list) != len(self.edges):

Of course, I'm not sure about this, so I didn't want to change it directly. Maybe user-defined priors have a different form from the default prior?

To recreate the issue, try the following:

st = StagedTree(df)
st.calculate_AHC_transitions()
prior = st.prior
st = StagedTree(df)
st.calculate_AHC_transitions(prior=prior)

JackStorrorCarter commented 7 months ago

Looking at the code a bit more, it seems that at some point the way to define the prior changed from the floret level to the edge level. Things like the default prior changed to this new edge-level form, while the check_prior function remained at the floret level. With the new edge format, the second error

raise ValueError(
    "The length of each sub-list in the list of priors "
    "must agree with the number of edges emanating from "
    "its corresponding situation."
)

is also not needed.
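To make the floret-level form concrete, here is a minimal sketch of the two length checks that the quoted error messages describe, written as a stand-alone function. It is not cegpy's implementation: `check_floret_prior` and its `edges_per_situation` argument are hypothetical names chosen for illustration.

```python
from typing import Sequence


def check_floret_prior(
    prior: Sequence[Sequence[float]],
    edges_per_situation: Sequence[int],
) -> None:
    """Validate a floret-level prior: one sub-list per situation.

    edges_per_situation[i] is the number of edges leaving situation i
    (a hypothetical input for this sketch, not a cegpy attribute).
    """
    if len(prior) != len(edges_per_situation):
        raise ValueError(
            "Number of sub-lists in the list of priors "
            "must agree with the number of situations."
        )
    for sub_list, n_edges in zip(prior, edges_per_situation):
        if len(sub_list) != n_edges:
            raise ValueError(
                "The length of each sub-list in the list of priors "
                "must agree with the number of edges emanating from "
                "its corresponding situation."
            )


# Two situations: the first has 2 outgoing edges, the second has 3.
check_floret_prior([[1.0, 1.0], [0.5, 0.5, 0.5]], [2, 3])
```

Under this reading, an edge-keyed dictionary with one entry per edge would fail the first check whenever the number of edges differs from the number of situations, which matches the error reported above.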

JackStorrorCarter commented 7 months ago

I've made a branch that hopefully fixes the issue. The check_prior function is now as below. Hopefully this still does the same checks as before (checking that the length of the prior is correct and that all values are positive). Perhaps a better version would check that the set of edges in the prior dict is the same as the edge set of the StagedTree?

def _check_prior(self, prior) -> None:
    # The prior is a dict keyed by edge, so it needs one entry per edge.
    if len(prior) != len(self.edges):
        raise ValueError(
            "Number of entries in the prior "
            "must agree with the number of edges."
        )

    # Reject negative values.
    for edge_prior in prior.values():
        if edge_prior < 0:
            raise ValueError("All priors must be non-negative.")
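The stricter check suggested above, comparing the keys of the prior with the edge set instead of only comparing lengths, could look roughly like the sketch below. It is a stand-alone illustration, not code from cegpy: `check_edge_prior` is a hypothetical name, and the edges are assumed to be hashable `(source, destination, label)` tuples.

```python
from typing import Hashable, Iterable, Mapping


def check_edge_prior(
    prior: Mapping[Hashable, float],
    edges: Iterable[Hashable],
) -> None:
    """Check that an edge-keyed prior covers exactly the given edges."""
    prior_edges = set(prior)
    tree_edges = set(edges)

    missing = tree_edges - prior_edges
    unexpected = prior_edges - tree_edges
    if missing or unexpected:
        raise ValueError(
            "Prior keys do not match the edge set: "
            f"missing={sorted(missing, key=repr)}, "
            f"unexpected={sorted(unexpected, key=repr)}"
        )

    for edge, value in prior.items():
        if value < 0:
            raise ValueError(f"Prior for edge {edge!r} must be non-negative.")


# Hypothetical edges written as (source, destination, label) tuples.
edges = [("s0", "s1", "yes"), ("s0", "s2", "no")]
check_edge_prior({edges[0]: 1.0, edges[1]: 1.0}, edges)
```

A set comparison catches a prior that has the right number of entries but the wrong keys, which a length check alone would accept.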

JackStorrorCarter commented 7 months ago

I've worked out the issue - I was writing the prior in the wrong way. The package works well as it is, so I deleted the branch. I'll try to update the documentation soon to make it easier to understand how to set user-defined priors.
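For readers who hit the same error: the messages quoted earlier in the thread suggest that a user-supplied prior is expected in the floret-level form (a list with one sub-list per situation), whereas the reproduction passed back the edge-keyed st.prior. A minimal sketch of regrouping edge-keyed values into per-situation sub-lists follows. It uses plain Python only; the `(source, destination, label)` tuples, the helper name `group_prior_by_situation`, and the ordering of situations and edges are all assumptions and may not match the ordering cegpy expects, so check the package documentation before relying on it.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Tuple

# Assumed edge representation: (source, destination, label).
Edge = Tuple[Hashable, Hashable, str]


def group_prior_by_situation(edge_prior: Dict[Edge, float]) -> List[List[float]]:
    """Group edge-keyed prior values into one sub-list per source node.

    Situations appear in order of first occurrence, and edges keep their
    insertion order within each situation.
    """
    grouped: Dict[Hashable, List[float]] = defaultdict(list)
    for (source, _destination, _label), value in edge_prior.items():
        grouped[source].append(value)
    return list(grouped.values())


edge_prior = {
    ("s0", "s1", "yes"): 2.0,
    ("s0", "s2", "no"): 2.0,
    ("s1", "s3", "low"): 1.0,
    ("s1", "s4", "high"): 1.0,
}
floret_prior = group_prior_by_situation(edge_prior)
print(floret_prior)  # [[2.0, 2.0], [1.0, 1.0]]
```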