arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Adding simple, multiple and hierarchical regression plots #512

Open GWeindel opened 5 years ago

GWeindel commented 5 years ago

I have written some functions to draw regression plots from mixed models fitted in pure Stan. I wonder whether adding such plots to arviz would be of interest (I guess so, judging by request #313). The basic idea would be to make it possible to plot (at least) linear regressions, simple effects, and interaction effects, with or without random effects (like sjPlot in R).

If such a project fits in the arviz package I could begin to code it, but I would definitely benefit from people with stronger skills.

ahartikainen commented 5 years ago

Sounds great. I think we need to think through the API carefully.

GWeindel commented 5 years ago

Great. If the other devs agree I can start to think about it, but it will surely take some time.

ahartikainen commented 5 years ago

Sure, no problem.

Do you have some model(s) you can share that could be used as a reference?

Also, do you think observed_data is a suitable location for the data or do we need some other structure?

aloctavodia commented 5 years ago

This will be a great addition to ArviZ. @GWeindel please be sure to check the plot_hpd function in case you find it useful for this project.

GWeindel commented 5 years ago

I am starting to have some doubts about the feasibility. It appears to me that either one constructs a post-fit structure, which then needs a lot of information about the fitted object (increasing the user's time and effort), or one has to control what goes into the model (e.g. Stan or PyMC3 code) and what comes out in order to draw these plots. Hence I would suggest that this project be built on top of a specialized package, perhaps like bambi (https://github.com/bambinos/bambi).

ahartikainen commented 5 years ago

What if we start with a regression plot built on InferenceData?

So the user needs to provide the following information at creation time:

- Scatter information (the data points)
- Model information (the line plot)

Then we need to define the same options as in ppcplot (do we take a subsample, etc.); a minimal sketch follows below.
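For instance, a toy InferenceData carrying those two pieces could look like this (a sketch with illustrative names m, b, x, and y; not a committed API):

```python
import numpy as np
import arviz as az

# toy data for y = m*x + b with noise
x = np.linspace(0, 1, 50)
y = 2 * x + 1 + np.random.normal(0, 0.2, size=50)

# fake posterior draws for the line parameters: 4 chains, 500 draws
posterior = {"m": np.random.normal(2, 0.1, size=(4, 500)),
             "b": np.random.normal(1, 0.1, size=(4, 500))}

# scatter information lives in observed/constant data, model information in the posterior
idata = az.from_dict(posterior=posterior,
                     observed_data={"y": y},
                     constant_data={"x": x})
```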

ahartikainen commented 5 years ago

After that we could implement multiple regression (where each dimension is either a new axis, or something similar).

And later do hierarchical structure.

Let's assume the user can provide the data.

ahartikainen commented 5 years ago

I was doing something simple today: linear regression...

It does get complicated fast.

We need a better interface to describe our models.

Getting the following to work is not hard:

x
y_data
y_ppc

What is harder:

y_model

It would be great if the user could give a function or something similar:

y_model = "m*x+b"
y_model = "y = m*x+b"
y_model = "y ~ x"

and then m, x, and b are found from the posterior.

Also, I'm not sure, but there could be an even better interface:

plot_lm("y ~ x", param=["m", "b"], data=data)
plot_lm("y ~ m*x+b", data=data)

Could this work with glm also? If we assume InferenceData has all the needed data, we just need to parse the formula and also accept numpy functions inside it:

plot_lm("exp(y) ~ m*x + log(x) + sqrt(b)", data=data)

How hard would it be if we did that parsing with re?

- ~ splits the y and x sides
- functions have ()
- everything else is a parameter found in InferenceData

Then, after parsing, we have y (and possibly an added pair for the ppc: {"y_hat": "y"}).
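A rough, purely illustrative sketch of that re-based split:

```python
import re

formula = "exp(y) ~ m*x + log(x) + sqrt(b)"

# "~" splits the formula into the y side and the x side
lhs, rhs = (side.strip() for side in formula.split("~"))

# names followed by "(" are functions; everything else should live in InferenceData
funcs = set(re.findall(r"([A-Za-z_]\w*)\s*\(", formula))
params = set(re.findall(r"[A-Za-z_]\w*", rhs)) - funcs

print(lhs)     # exp(y)
print(funcs)   # {'exp', 'log', 'sqrt'}
print(params)  # {'m', 'x', 'b'}
```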

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_lm(x, y_ppc, y_data, y_model, data, x_group=None, y_ppc_group=None, y_data_group=None, num_ppc_samples=100):
    """Plot lm.

    Parameters
    ----------
    x : str or Sequence
    y_ppc : str
    y_data : str or Sequence
    y_model : xarray.DataArray
        Model-implied mean with chain and draw dimensions
    data : obj or list[obj]
        Any object that can be converted to an az.InferenceData object.
        Refer to the documentation of az.convert_to_dataset for details.
    x_group : str, optional
    y_ppc_group : str, optional
    y_data_group : str, optional
    num_ppc_samples : int

    Returns
    -------
    axes
    """
    # resolve x to a group, preferring observed_data
    if isinstance(x, str):
        if x_group is None:
            groups = data._groups
            if hasattr(data, "observed_data"):
                groups = ["observed_data"] + [group for group in groups if group != "observed_data"]
            for group in groups:
                item = getattr(data, group)
                if x in item and x_group is None:
                    x_group = group
                elif x in item:
                    print("Warning, duplicate variable names for x, using variable from group {}".format(x_group))
        x_values = getattr(data, x_group)[x]

    # resolve y_ppc to a group, preferring posterior_predictive
    if isinstance(y_ppc, str):
        if y_ppc_group is None:
            groups = data._groups
            if hasattr(data, "posterior_predictive"):
                groups = ["posterior_predictive"] + [group for group in groups if group != "posterior_predictive"]
            for group in groups:
                item = getattr(data, group)
                if y_ppc in item and y_ppc_group is None:
                    y_ppc_group = group
                elif y_ppc in item:
                    print("Warning, duplicate variable names for y_ppc, using variable from group {}".format(y_ppc_group))
        y_ppc_values = getattr(data, y_ppc_group)[y_ppc]

    # resolve y_data to a group, preferring observed_data
    if isinstance(y_data, str):
        if y_data_group is None:
            groups = data._groups
            if hasattr(data, "observed_data"):
                groups = ["observed_data"] + [group for group in groups if group != "observed_data"]
            for group in groups:
                item = getattr(data, group)
                if y_data in item and y_data_group is None:
                    y_data_group = group
                elif y_data in item:
                    print("Warning, duplicate variable names for y_data, using variable from group {}".format(y_data_group))
        y_data_values = getattr(data, y_data_group)[y_data]

    fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=100)

    # plot the observed data
    ax.plot(x_values, y_data_values, marker='.', color='C3', lw=0, zorder=10)

    # plot uncertainty in y: a random subsample of posterior predictive draws
    y_ppc_stacked = y_ppc_values.stack(sample=("chain", "draw"))
    slicer = np.random.choice(y_ppc_stacked.sizes["sample"], size=num_ppc_samples, replace=False)
    y_ppc_values_ = y_ppc_stacked[..., slicer]
    for i in range(num_ppc_samples):
        ax.plot(x_values, y_ppc_values_[..., i], marker='.', lw=0, alpha=0.1, color='C1')

    # plot uncertainty in the line: the same subsample of model-implied means
    y_model_values = y_model.stack(sample=("chain", "draw"))[..., slicer]
    for i in range(num_ppc_samples):
        ax.plot(x_values, y_model_values[..., i], lw=0.5, alpha=0.2, c='k')

    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.grid(True)
    return ax
```

[image: example output of the sketch above]

jankaWIS commented 3 years ago

Speaking of which, I was wondering: is there currently anything in arviz like regplot in seaborn? That would be great, and it could also be a starting point for what has been asked here.

utkarsh-maheshwari commented 3 years ago

Just a thought: instead of asking for y_model, can't we calculate m and b inside the plot_lm function? It would increase the complexity of the function, but it would reduce the complexity at the input end and make it more user-friendly.

OriolAbril commented 3 years ago

> Instead of asking for y_model, can't we calculate m and b inside the plot_lm function?

The problem is that there is no way to know what y_model is in ArviZ (it could be possible at a higher level like bambi, but not in ArviZ). It depends on the model: it could be y ~ b1*x + b0, but it could also have multiple covariates, higher-order terms, splines...

utkarsh-maheshwari commented 3 years ago

@ahartikainen What assumptions do we make about the data groups that should be present in the InferenceData passed as input?

This is an example, kidiq, that I am taking from posteriordb, but there is no posterior predictive here, just the posterior. Can it be used as an example? [image]

ahartikainen commented 3 years ago

Good question. I think we need to calculate the posterior predictive with Python.
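For a simple linear model like kidiq that can be done by hand. A sketch, assuming the posterior holds b0, b1, and sigma, and that mom_iq is a 1-D numpy array (all names illustrative):

```python
import numpy as np

# flatten chains and draws into one sample dimension
post = idata.posterior.stack(sample=("chain", "draw"))

# model-implied mean, shape (n_obs, n_samples)
mu = post["b0"].values + post["b1"].values * mom_iq[:, None]

# posterior predictive: one normal draw per (observation, sample) pair
y_hat = np.random.normal(mu, post["sigma"].values)
```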

utkarsh-maheshwari commented 3 years ago

https://gist.github.com/utkarsh-maheshwari/8d4cd2fd84c763bf85291c3f0881d588

Here is my initial try at visualizing linear regression models, inspired by plot_posterior_predictive_glm from pymc3. There are still lots of things that need to be considered, though.

OriolAbril commented 3 years ago

Use

```python
with pm.Model() as model:
    mom_iq = pm.Data("mom_iq", data["mom_iq"])

    sigma = pm.HalfNormal('sigma', sd=10)
    intercept = pm.Normal('Intercept', 0, sd=10)
    x_coeff = pm.Normal('slope', 0, sd=10)

    mean = intercept + x_coeff * mom_iq
    likelihood = pm.Normal('kid_score', mu=mean,
                           sd=sigma, observed=data["kid_score"])

    idata = pm.sample(1000, return_inferencedata=True)
```

so mom_iq gets automatically stored as constant data. Moreover, we should definitely not convert to a dataframe:

```python
idata.posterior["Intercept"] + idata.posterior["slope"] * idata.constant_data["mom_iq"]
```

will work with xarray out of the box and avoids the need to loop for the computation. Ari's function above has an example of stacking to get a random subsample; see the sketch below.
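Putting the two together, a sketch:

```python
import numpy as np

# vectorized line computation, no Python loop needed
y_model = (idata.posterior["Intercept"]
           + idata.posterior["slope"] * idata.constant_data["mom_iq"])

# flatten chains and draws, then take a random subsample of 100 lines
stacked = y_model.stack(sample=("chain", "draw"))
idx = np.random.choice(stacked.sizes["sample"], size=100, replace=False)
subsample = stacked.isel(sample=idx)
```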

Also, a bit of a side note: eval is a Python built-in, so it's not a good idea to use it as a variable name.

utkarsh-maheshwari commented 3 years ago

@OriolAbril Thank you for the suggestions; I made the suggested changes. I think that here, visualizing the uncertainty in the y points adds little because the points are closely packed. (Should we include an option to show it?)

There are many points that still need to be considered for the plot_lm function.

Open to suggestions.

utkarsh-maheshwari commented 3 years ago

I guess using plot_hdi, as suggested by @aloctavodia, would make it look great.
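For example, a sketch assuming y_model is the model-implied mean as a DataArray with chain and draw dimensions:

```python
import arviz as az

# shade the HDI band of the regression lines instead of spaghetti-plotting them
ax = az.plot_hdi(x_values, y_model, hdi_prob=0.94, color="C1")
ax.plot(x_values, y_model.mean(dim=("chain", "draw")), color="k")
```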

@ahartikainen, about y_model, I think we can do it like this?

Should I open a new issue to discuss plot_lm and its visualization specifically? Otherwise, this issue will stretch very long.


utkarsh-maheshwari commented 3 years ago

> Also, do you think observed_data is a suitable location for the data or do we need some other structure?

I think the data could be in constant_data as well.
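The variable lookup could then try both groups, e.g. (sketch):

```python
# prefer observed_data but fall back to constant_data for the predictor
for group in ("observed_data", "constant_data"):
    dataset = getattr(idata, group, None)
    if dataset is not None and "mom_iq" in dataset:
        x_values = dataset["mom_iq"]
        break
```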

utkarsh-maheshwari commented 3 years ago

I have tried modifying Ari's plot_lm function with some added features.

Achieved this: [image]

Input:

```python
plot_lm(
    x="mom_iq",
    y_ppc="kid_score",
    y_data="kid_score",
    data=idata,
    y_model="kid_score ~ Intercept + slope * mom_iq",
)
```

utkarsh-maheshwari commented 3 years ago

> I think we need to calculate the posterior predictive with Python.

Can we use pm.sample_posterior_predictive() to calculate it?

ahartikainen commented 3 years ago

> I think we need to calculate the posterior predictive with Python.
>
> Can we use pm.sample_posterior_predictive() to calculate it?

It depends on which PPL you used for the model.
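For the PyMC3 model above, for instance, it could look like this (a sketch; other PPLs need their own equivalent):

```python
import arviz as az
import pymc3 as pm

with model:
    # draw posterior predictive samples from the fitted model
    ppc = pm.sample_posterior_predictive(idata)

# attach them to the existing InferenceData as a posterior_predictive group
az.concat(idata, az.from_pymc3(posterior_predictive=ppc, model=model), inplace=True)
```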