JuliaPlots / StatsPlots.jl

Statistical plotting recipes for Plots.jl

Alternative Corner Plot? #391

Open farr opened 4 years ago

farr commented 4 years ago

I'm afraid I really don't like the format of cornerplot when plotting the outputs of MCMC simulations. Normally I don't care so much about the correlation coefficient of my samples (and I certainly don't want to color the scatterplots by correlation coefficient); ideally, I would like to show both some estimate of the 2D density in the off-diagonal grid squares and also show the samples (as a sanity check of the density estimate).

The code below implements something like what I prefer. Is there a way to slip this recipe into StatsPlots.jl, perhaps with a better name than alternativecornerplot ;)? If others are interested, I'd be happy to file a formal pull request. My personal preference would be to simply replace the existing cornerplot recipe, but of course I understand that there may be reluctance to make such a dramatic change. Any suggestions for good names?

The code:

using KernelDensity
using LaTeXStrings  # only needed for the L"..." labels in the usage example
using Plots
using StatsPlots

@userplot CornerPlot
@recipe function f(cp::CornerPlot)
    m = cp.args[1]

    nl = get(plotattributes, :levels, 10)
    N = size(m, 1)

    labs = pop!(plotattributes, :label, ["x$i" for i=1:N])
    if labs!=[""] && length(labs)!=N
        error("Number of labels not identical to number of datasets")
    end

    legend := false
    layout := (N,N)

    for i in 1:N
        # Do the diagonals
        @series begin
            subplot := i + (i-1)*N
            seriestype := :density
            xlims := (minimum(m[i]), maximum(m[i]))
            ylims := (0, Inf)
            xguide := labs[i]
            x := m[i]
        end
    end

    for i in 1:N
        for j in 1:(i-1)
            # Do the kdeplots
            k = kde((m[j], m[i]))
            dv = vec(k.density)
            inds = reverse(sortperm(dv))
            cd = cumsum(dv[inds])
            C = cd[end]

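            # Pick contour heights so that the region above level l holds
            # a fraction l/(nl+1) of the total probability mass.
            levels = Float64[]
            for l in 1:nl
                frac = l/(nl+1)
                cf = frac*C
                ind = searchsortedfirst(cd, cf)
                push!(levels, dv[inds[ind]])
            end
            levels = reverse(levels)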

            @series begin
                seriestype := :contour
                subplot := (i-1)*N + j
                seriescolor --> :viridis
                x := k.x
                y := k.y
                z := permutedims(k.density)
                levels := levels
                xlims := (minimum(m[j]), maximum(m[j]))
                ylims := (minimum(m[i]), maximum(m[i]))
                xguide := labs[j]
                yguide := labs[i]
                k.x, k.y, permutedims(k.density)
            end
        end
    end

    for i in 1:N
        for j in (i+1):N
            # Do the scatterplots
            @series begin
                seriestype := :scatter
                subplot := (i-1)*N + j
                x := m[j]
                y := m[i]
                markersize --> 0.1
                xlims := (minimum(m[j]), maximum(m[j]))
                ylims := (minimum(m[i]), maximum(m[i]))
                xguide := labs[j]
                yguide := labs[i]
                m[j], m[i]
            end
        end
    end
end
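
The contour levels in the KDE loop are chosen so that each contour encloses a fixed fraction of the probability mass, rather than being evenly spaced in density. A minimal standalone sketch of that step, assuming a hypothetical helper name `mass_levels` (it is not part of StatsPlots.jl or KernelDensity.jl) and a density evaluated on a regular grid:

```julia
# Hypothetical helper, for illustration only: given a gridded density,
# return contour heights such that the region where the density is at or
# above each height holds the requested fraction of the total mass.
function mass_levels(density::AbstractArray{<:Real}, fractions)
    dv = sort(vec(density); rev=true)   # cell densities, highest first
    cd = cumsum(dv)                     # running mass (up to the constant cell area)
    total = cd[end]
    levels = map(fractions) do f
        # first cell at which the running mass reaches f * total
        ind = min(searchsortedfirst(cd, f * total), length(dv))
        dv[ind]
    end
    return sort(levels)                 # contour levels in increasing order
end

# Toy 2x2 grid with cell masses 4, 2, 1, 1 (total 8).
grid = [4.0 1.0; 2.0 1.0]
levels = mass_levels(grid, [0.5, 0.75])
```

In the recipe above, the fractions are `l/(nl+1)` for `l` in `1:nl`, so the default `nl = 10` gives ten contours at equally spaced mass fractions.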

Usage: obtain an MCMC chain, and then

@df trace cornerplot([:a, :b, :c], label=[L"a", L"b", L"c"], size=(1000, 1000))

will produce

[Image: corner]

PythonNut commented 3 years ago

I am coming from seaborn and am also accustomed to using these sorts of plots (via PairGrid).

scheidan commented 3 years ago

That's very neat, thanks! With the existing corner plot I'm also irritated by the line and the coloring.

Maybe a single corner recipe with more options, instead of multiple alternative recipes, is the better way to go.

A 2D histogram could be an alternative to the density estimate, but that's only a personal preference.

PythonNut commented 3 years ago

One of the nice things about seaborn's PairGrid is that you can customize the upper triangle, diagonal, and lower triangle plots to be whatever you like (e.g. histogram, scatter, kde, etc.).

sethaxen commented 3 years ago

Both ArviZPlots.jl and CornerPlot.jl do something similar to this. It probably would be good to offer this alternative layout using keyword arguments.
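
One way to picture the keyword-argument idea is a grid whose upper triangle, diagonal, and lower triangle are each drawn by a caller-supplied plotting function, in the spirit of seaborn's PairGrid. The sketch below is only an illustration built from ordinary Plots.jl calls; the function name `pairgrid` and its keywords `upper`, `diag`, and `lower` are assumptions, not an existing StatsPlots.jl API, and a real implementation would more likely be a recipe.

```julia
using Plots

# Hypothetical helper, for illustration only: `cols` is a vector of
# equal-length sample vectors, one per parameter.
function pairgrid(cols; upper=scatter, diag=histogram, lower=histogram2d)
    N = length(cols)
    # Build the panels in row-major order, which is the order in which
    # `plot` fills a grid layout.
    panels = [
        i == j ? diag(cols[i]) :
        i < j  ? upper(cols[j], cols[i]) :
                 lower(cols[j], cols[i])
        for i in 1:N for j in 1:N
    ]
    return plot(panels...; layout=(N, N), legend=false)
end

# Three parameters of fake "samples", with the default panel types.
samples = [randn(500) for _ in 1:3]
p = pairgrid(samples)
```

Swapping a panel type is then a matter of passing a different function, for example `pairgrid(samples; lower=scatter)`.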