pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Make `util.plot_gp_dist` transparency-proof #4591

Open MaPePeR opened 3 years ago

MaPePeR commented 3 years ago

Description of your problem

The plot_gp_dist function handles fill_alpha in a weird way (at least to me):

    for i, p in enumerate(percs[::-1]):
        upper = np.percentile(samples, p, axis=1)
        lower = np.percentile(samples, 100 - p, axis=1)
        color_val = colors[i]
        ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=fill_alpha, **fill_kwargs)

Because the band between 1% and 99% and the band between 2% and 98% overlap, the fill_alpha parameter may not have the intended effect. Stacking many transparent bands leaves the central percentiles with practically no transparency: 40 semi-transparent bands are stacked in the center. The stacking and color mixing may also distort the colormap towards the center.
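To make the overlap concrete, a small sketch can count how many of the old bands cover a given percentile. It assumes the old loop uses `percs = np.linspace(51, 99, 40)` and draws one band from `100 - p` to `p` per entry, as in the script posted at the end of this thread; `n_layers` is a hypothetical helper name.

```python
import numpy as np

# Upper percentiles of the old plot_gp_dist loop (one band from 100 - p to p each)
percs = np.linspace(51, 99, 40)


def n_layers(q):
    """Count how many old-style bands [100 - p, p] contain percentile q."""
    return int(np.sum((100 - percs <= q) & (q <= percs)))


for q in (1, 25, 50, 75, 99):
    print(f"percentile {q:>2}: {n_layers(q)} stacked bands")
```

The median sits under all 40 bands, the quartiles under 20, and the 1st and 99th percentiles under a single band, so the effective opacity varies strongly across the plot.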

Test graphs were made with rtcovidlive/rtlive-global using the dataset from here, but heavily cropped.

What it looks like now (fill_alpha=0.6):

[image]

When setting fill_alpha=0.6, fill_kwargs={"edgecolor":"none"}, the hard lines disappear or get softer, but the fill is still not 60% transparent:

[image]

What I expected (modified code with fill_alpha=0.6, fill_kwargs={"edgecolor":"none"}):

[image]

Code replacing these lines:

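    cmap = plt.get_cmap(palette)
    N = 81
    percs = np.linspace(1, 99, N)
    # color value is 1.0 at the median and falls to 0.0 at the outermost percentiles
    colors = 1 - np.abs(percs - 50) / (np.max(percs) - 50)
    samples = samples.T
    x = x.flatten()
    # one row of bounds per percentile: shape (N, len(x))
    percs_bounds = np.percentile(samples, percs, axis=1)
    # default to no edges, but let fill_kwargs override it (passing edgecolor twice would raise a TypeError)
    kwargs = {"edgecolor": "none", **(fill_kwargs or {})}
    # fill only the band between neighbouring percentiles, so no two bands overlap
    for i in range((N - 1) // 2):
        color_val = colors[i]
        # band below the median and its mirror image above the median
        ax.fill_between(x, percs_bounds[i], percs_bounds[i + 1], facecolor=cmap(color_val), alpha=fill_alpha, **kwargs)
        ax.fill_between(x, percs_bounds[-i-1], percs_bounds[-i-2], facecolor=cmap(color_val), alpha=fill_alpha, **kwargs)
    if N % 2 == 0:
        # an even N leaves one central band between the two middle percentiles
        ax.fill_between(x, percs_bounds[N//2 - 1], percs_bounds[N//2], facecolor=cmap(1.0), alpha=fill_alpha, **kwargs)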

Fixing or changing this behavior would have to be considered a backwards-compatibility (BC) break, I guess. But maybe I'm also missing something, and this style of "progressive transparency" was intended?


OriolAbril commented 3 years ago

I have not checked, but this looks more like an issue of how the matplotlib colormap is used. The colormap seems to be a gradient from white to red or green, and behaves accordingly. To get the desired visual effect, the colormap should be a gradient from transparent to red or green.

This Stack Overflow answer seems to achieve this effect by using RGBA colors with a linspace as the alpha channel along the colormap, from 0 at white (thus transparent) to 1 at green.

I think that using an alpha value will override the inherent alpha of the colormap, so you may need to set fill_alpha to None. In that case, you may want to cap the linspace at what used to be fill_alpha instead of 1.
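A quick way to check the override claim, assuming a recent matplotlib and the headless Agg backend: an RGBA color passed to `fill_between` should keep its own alpha when `alpha=None`, while a numeric `alpha` replaces it. The color below is a made-up stand-in for a colormap entry that carries its own alpha.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.array([0.0, 1.0])
rgba = (1.0, 0.0, 0.0, 0.3)  # hypothetical colormap entry with alpha = 0.3

fig, ax = plt.subplots()
# alpha=None: the collection keeps the alpha stored in the RGBA color
keep = ax.fill_between(x, 0, 1, color=rgba, alpha=None)
# numeric alpha: the collection's alpha replaces the color's own alpha
override = ax.fill_between(x, 0, 1, color=rgba, alpha=0.8)

print(keep.get_facecolor()[0][3], override.get_facecolor()[0][3])
```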

MaPePeR commented 3 years ago

The last time I tried to use a transparent colormap, I ran into an issue with showing the colorbar and had to use this workaround.

Anyway: Setting a transparent colormap wouldn't solve the problem here, because the transparent regions would still overlap, reducing the effective transparency:

    |   60%             fill_between for  1% - 99%
    ||  60% ⊕ 60%       fill_between for 30% - 70%
    ||| 60% ⊕ 60% ⊕ 60% fill_between for 40% - 60%
    ||
    |

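
The ⊕ in the diagram is ordinary "over" alpha compositing. For n identical layers of opacity a, the combined opacity is 1 - (1 - a)^n; the hypothetical helper below (not part of pymc) shows how quickly 60% layers approach full opacity.

```python
def stacked_alpha(alpha, n):
    """Combined opacity of n identical layers composited with the "over" operator."""
    return 1 - (1 - alpha) ** n


for n in (1, 2, 3, 40):
    print(n, round(stacked_alpha(0.6, n), 4))
```

Two 60% layers already give 84%, three give 93.6%, and 40 layers are indistinguishable from fully opaque.
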
michaelosthege commented 3 years ago

Hi @MaPePeR ! Sorry I did not notice this issue until now.

@lhelleckes and I faced the same problem on another project, and the discussion above has all the relevant bits. To summarize:

  1. Create a colormap that transitions [white+transparent] → [color+opaque], for example with the helper function below.
  2. Also set plot_gp_dist(..., fill_alpha=None)

I'm not too sure that the result is quantitatively correct though. If someone wants to improve plot_gp_dist in that regard that'd be awesome!

Obviously, we haven't gotten around to applying this to the Rtlive.de plots yet.

    import numpy
    from matplotlib import colors


    def transparentify(cmap: colors.Colormap) -> colors.ListedColormap:
        """Creates a transparent->color version from a standard colormap.

        Stolen from https://stackoverflow.com/a/37334212/4473230

        Testing
        -------
        The following code block can be used to plot a (transparent) colormap in a way that one
        can check if the transparency actually works. This is not trivial because the background
        is often white already.
        Check the thread under https://github.com/matplotlib/matplotlib/pull/17888#issuecomment-845253158
        for updates about automatic cmap representation in notebooks that could make this snippet obsolete.

        from matplotlib import cm, pyplot
        redsT = transparentify(cm.Reds)
        x = numpy.arange(256)
        fig, ax = pyplot.subplots(figsize=(12,1))
        ax.scatter(x, numpy.ones_like(x) - 0.01, s=100, c=[
            cm.Reds(v)
            for v in x
        ])
        ax.scatter(x, numpy.ones_like(x) + 0.01, s=100, c=[
            redsT(v)
            for v in x
        ])
        ax.set_ylim(0.9, 1.1)
        pyplot.show()
        """
        # Get the colormap colors as an (N, 4) RGBA array
        cm_new = numpy.array(cmap(numpy.arange(cmap.N)))
        # Replace the alpha channel with a linear ramp: 0 (transparent) -> 1 (opaque)
        cm_new[:, 3] = numpy.linspace(0, 1, cmap.N)
        return colors.ListedColormap(cm_new)

MaPePeR commented 3 years ago

@michaelosthege I don't think the transparent color map would be quantitatively correct, because of the overlapping of fill_between regions. How big the effect is depends on the transparency, of course.

The code to improve plot_gp_dist is in the issue description. I could submit it as a pull request, but I had concerns about the BC break, or that I had missed something and the transparency is actually desired behavior, so I wanted to open a discussion first.

michaelosthege commented 3 years ago

I don't think there were any quantitative intentions about defaulting to fill_alpha=0.8. It was probably more about the looks when there's more content in the plot.

Switching to fill_alpha=None by default should be just fine, particularly since we have the major release of v4 coming up anyways. If we also include the helper function from above and maybe apply it to some popular cmaps like "Reds", "Greens", "Blues" we should have 90 % of the use cases covered.
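A sketch of what applying the helper to some popular colormaps could look like, assuming matplotlib 3.5 or newer (for `matplotlib.colormaps`). `TRANSPARENT_CMAPS` and the `_T` name suffix are hypothetical, and the helper re-implements the transparentify idea from earlier in the thread so the block is self-contained.

```python
import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap


def transparentify(cmap):
    """Keep the RGB entries of cmap and ramp the alpha channel linearly from 0 to 1."""
    rgba = np.array(cmap(np.arange(cmap.N)))
    rgba[:, 3] = np.linspace(0, 1, cmap.N)
    return ListedColormap(rgba, name=f"{cmap.name}_T")


# Hypothetical lookup table of pre-built transparent variants
TRANSPARENT_CMAPS = {
    name: transparentify(matplotlib.colormaps[name])
    for name in ("Reds", "Greens", "Blues")
}
```

Since plot_gp_dist passes `palette` through `plt.get_cmap`, which also accepts a Colormap instance, a call along the lines of `plot_gp_dist(ax=ax, samples=samples, x=x, palette=TRANSPARENT_CMAPS["Reds"], fill_alpha=None)` should pick these up.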

To check if it's quantitatively correct maybe use the code from the docstring side-by-side with the fill_between-based density gradient?

MaPePeR commented 3 years ago

I cobbled something together to compare the old (current) plot method with the method using the fix from above:

[image: transparency_result]

The "new" plot method creates uniform transparency with the normal colormap for all percentiles and does not distort the transparency of the transparentify'd colormap. (With the old plot method, the plot is at 100% alpha for ~25% of the band, even though the colormap says it should only be fully opaque at the 50th percentile.)

OriolAbril commented 3 years ago

This comparison is great! Thanks!

I personally like the transparentified new version best. The transparentify function could also take an alpha value, so that the transparency ranges from 0 to alpha instead of always from 0 to 1. Also tagging @ColCarroll, who I think will like this too.
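A minimal sketch of that suggestion, assuming matplotlib 3.5 or newer; the `max_alpha` parameter name is hypothetical. The RGB entries stay untouched and only the upper end of the alpha ramp changes.

```python
import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap


def transparentify(cmap, max_alpha=1.0):
    """Ramp the alpha channel of cmap linearly from 0 up to max_alpha (RGB unchanged)."""
    if not 0 <= max_alpha <= 1:
        raise ValueError("max_alpha must lie in [0, 1]")
    rgba = np.array(cmap(np.arange(cmap.N)))
    rgba[:, 3] = np.linspace(0, max_alpha, cmap.N)
    return ListedColormap(rgba)


greens_t = transparentify(matplotlib.colormaps["Greens"], max_alpha=0.8)
```

Combined with `fill_alpha=None`, the plot's own alpha no longer overrides the ramp, so `max_alpha` plays the role the old `fill_alpha=0.8` default used to play.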

michaelosthege commented 3 years ago

Wow, that's a great comparison!

Having a slight color vision deficiency I guess I'm a little disqualified here, but the rightmost looks a little desaturated? When it comes to visibility of the line in the back however it's clearly the best.

Compared to the leftmost, all the others seem more "narrow". Does that mean that our current one exaggerates the uncertainty?

MaPePeR commented 3 years ago

The rightmost plot is desaturated, but that is probably by design: because of the linear alpha ramp added to the colormap, 80% of it is less opaque than in the other plots. The colormap getting brighter at the same time as it gets more transparent might be at fault here, but that is exactly what the transparentify'd colormap is designed to do. (So it is desaturated quadratically instead of linearly?)
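A toy calculation supports the "quadratic" guess, under the simplifying assumption of a colormap that runs linearly from white (1, 1, 1) to pure red (1, 0, 0), combined with the linear 0 → 1 alpha ramp and blended over a white background.

```python
import numpy as np

t = np.linspace(0, 1, 5)  # position along the colormap
# hypothetical colormap: linear from white (1, 1, 1) at t=0 to pure red (1, 0, 0) at t=1
cmap_rgb = np.stack([np.ones_like(t), 1 - t, 1 - t], axis=1)
alpha = t  # transparentify-style linear alpha ramp
# source-over blend onto a white background, channel by channel
visible = alpha[:, None] * cmap_rgb + (1 - alpha[:, None]) * 1.0

print(np.column_stack([t, 1 - t, visible[:, 1]]))  # t, opaque green, visible green
```

The visible green (and blue) channel is t(1 - t) + (1 - t) = 1 - t², instead of 1 - t for the opaque colormap, so the lower half of the colormap stays much closer to white.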

I updated the gist to also add a non-transparent version and changed the new method to use the same percs as the old code, so they are easier to compare. Before, I used percs = np.linspace(1, 99, N); now I use percs = np.concatenate([np.linspace(1, 49, 40), np.linspace(51, 99, 40)]), which matches the percs of the old version.

There are still some differences though, so there might be something wrong with either or both of the codes:

[image: transparency_result2]

I think the difference is in the calculation of the colors variable, but I'm not sure which one is right:

    # old
    colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs))
    # new
    colors = 1 - np.abs(percs - 50) / (np.max(percs) - 50)

Note that for the old version, colors only contains the colors for percs = np.linspace(51, 99, 40), while for the new one the colors variable has double the length, for percs = np.concatenate([np.linspace(1,49, 40), np.linspace(51, 99, 40)]).

Thinking about it, the old code scales the whole colormap between 1% and 49%, while the new code scales between 1% and 50%? So I'm inclined to say the new code is more correct, but it is too close to bedtime for me to be sure. Maybe the problem is that the code picks the color based on one side of the percentile interval instead of its center?
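One possible reading of "pick the color at the center of the interval", as a sketch: `band_colors` is a hypothetical helper, and the linear fall-off from the median is an assumption rather than what either version of the code does.

```python
import numpy as np


def band_colors(percs):
    """Colormap value per band between neighbouring percentiles, evaluated at the band center.

    1.0 for a band centered on the median, falling linearly to 0.0 for the outermost band.
    """
    percs = np.asarray(percs, dtype=float)
    centers = (percs[:-1] + percs[1:]) / 2
    dist = np.abs(centers - 50)
    return 1 - dist / dist.max()


percs = np.concatenate([np.linspace(1, 49, 40), np.linspace(51, 99, 40)])
colors = band_colors(percs)
```

With these percs there are 79 bands; band i and its mirror image 78 - i get the same value, and the central band between 49% and 51% gets exactly 1.0.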

I do think the current one might exaggerate the uncertainty when there is transparency, but also note that the Reds colormap doesn't have linear brightness to begin with. I chose it because it was the default. Sadly, there is no colormap fading to white with viridis-level brightness linearity.

Maybe I can figure out a way to generate graphs similar to those in the matplotlib colormap documentation for the different versions of transparently rendered colormaps, but that's a project for another day.

MaPePeR commented 3 years ago

Updated the gist again. I added the full colormap to the left and right of the fully opaque colormap plot for comparison, and changed the color scaling for the new plot method, because it looked wrong in comparison. There are still differences, but I'm more confident that the new plot method is now more correct, because the color of each region matches the center of the full colormap.

[image: transparency_result3]

michaelosthege commented 3 years ago

Maybe we should test with cm.Greys first to get a more objective comparison of lightness. The thin lines in the center below each text are the full colormap.

[image: Figure_1]

The code is here: https://gist.github.com/michaelosthege/96ea2e8b031945f3bee1cf7ceb7c27a3

Any alpha in the fill_between or the color seems to make the band narrower than it should be.

However, we should keep in mind that plot_gp_dist visualizes the percentiles! After we get our lightness and transparency right, we should consider deprecating plot_gp_dist in favor of something like arviz.plot_band(..., kind="density") (cc @OriolAbril).

MaPePeR commented 3 years ago

Great improvements to the code @michaelosthege! 👍

I forked it and added my attempt to calculate and plot the luminance (requires colorspacious):

[image: transparency_result5]
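colorspacious provides perceptual color spaces for this kind of check. As a dependency-free stand-in (a swapped-in technique, not necessarily the computation used in the fork), CIE 1976 L* can be computed from sRGB with numpy alone; `srgb_lightness` is a hypothetical helper name.

```python
import numpy as np


def srgb_lightness(rgb):
    """CIE 1976 L* (0 = black, 100 = white) of sRGB colors with channels in [0, 1]."""
    rgb = np.asarray(rgb, dtype=float)[..., :3]  # drop an alpha channel if present
    # undo the sRGB transfer curve to get linear-light values
    lin = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    # relative luminance Y with sRGB / Rec. 709 primaries (white point: Y = 1)
    y = lin @ np.array([0.2126, 0.7152, 0.0722])
    # CIE L* transfer function, with its linear segment near black
    delta = 6 / 29
    f = np.where(y > delta ** 3, np.cbrt(y), y / (3 * delta ** 2) + 4 / 29)
    return 116 * f - 16
```

Applied to the (N, 4) output of a colormap, for example `srgb_lightness(matplotlib.colormaps["Reds"](np.linspace(0, 1, 256)))`, it returns the lightness profile along the colormap. Colors with transparency would first need to be blended onto the background.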

The line for the new method using fill_alpha = 0.8 makes sense, because by applying transparency/mixing white into the colormap, the colormap is compressed, so it becomes narrower.

If you want increasing transparency, the best option might be to use the transparentify function on a colormap with a fixed color, so that the color channels stay fixed and only the alpha channel changes. I did this by adding cm_new[:, :3] = cm_new[-1, :3] to transparentify:

[image: transparency_result6]

This also makes the problem with the old plot method very obvious. EDIT: If you try this with a non-black colormap the result looks very wrong/bad:

[image: transparency_result6, non-black colormap]
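For reference, a self-contained version of that fixed-color variant, assuming matplotlib 3.5 or newer; `transparentify_fixed_color` is a hypothetical name. Every entry takes the RGB of the last (darkest) colormap entry and only the alpha channel varies.

```python
import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap


def transparentify_fixed_color(cmap):
    """Single-hue colormap: RGB of cmap's last entry, alpha ramping linearly from 0 to 1."""
    rgba = np.array(cmap(np.arange(cmap.N)))
    rgba[:, :3] = rgba[-1, :3]  # broadcast the last entry's RGB to every row
    rgba[:, 3] = np.linspace(0, 1, cmap.N)
    return ListedColormap(rgba)


greys_fixed = transparentify_fixed_color(matplotlib.colormaps["Greys"])
reds_fixed = transparentify_fixed_color(matplotlib.colormaps["Reds"])
```

With Greys this gives black at varying opacity. With Reds it gives one dark red at varying opacity, so the light, pinkish entries of the original colormap are gone, which may be why non-black colormaps look wrong with this approach.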

michaelosthege commented 3 years ago

Great work with the luminance! Sorry, I should have forked your gist too. Didn't have it on my radar that it was an option.

Your rightmost version clearly has a good luminance profile, but I'm worried that it doesn't match the colormap. To me, one of the best options so far is still the old method with fill_alpha=None and a transparentified cmap (the 2nd from the right in this post). It matches the colormap and occludes the blue line in regions of high "density". On the other hand, we're not actually plotting a density...

But maybe we shouldn't be overlaying the fill_betweens at all and instead plot them just between the percentile steps?

MaPePeR commented 3 years ago

Your rightmost version clearly has a good luminance profile, but I'm worried that it doesn't match the colormap.

Yeah, I agree. It produces a very weird color that looks more gray than red to me.

But maybe we shouldn't be overlaying the fill_betweens at all and instead plot them just between the percentile steps?

That's exactly what the "new" code is doing.

For plotting multiple densities the alpha value might not be the solution at all, because it washes out the color and distorts the colormap somewhat even when there is no overlap. Maybe something like additive or multiplicative blending is better, but I have no clue if that's possible in matplotlib (I don't think so - probably not easy?) and also not sure if that would actually be better.
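To illustrate the idea numerically only (this is plain numpy, not a matplotlib feature; as far as I know matplotlib composites artists with ordinary alpha blending), here is a sketch of multiplicative blending. On a white background, non-overlapping bands reproduce the colormap exactly, while anything behind them stays visible, just darkened. The colors and the `multiply_blend` helper are hypothetical.

```python
import numpy as np


def multiply_blend(layers, background=(1.0, 1.0, 1.0)):
    """Multiplicative blending: every layer scales the color underneath it channel-wise."""
    out = np.asarray(background, dtype=float)
    for rgb in layers:
        out = out * np.asarray(rgb, dtype=float)
    return out


band = np.array([0.8, 0.2, 0.2])  # hypothetical band color
line = np.array([0.0, 0.0, 1.0])  # hypothetical blue line behind the bands

print(multiply_blend([band]))                   # on white: exactly the band color
print(multiply_blend([band, band]))             # overlapping bands get darker
print(multiply_blend([band], background=line))  # the blue line shows through, darkened
```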

michaelosthege commented 3 years ago

That's exactly what the "new" code is doing.

Oh I didn't see that.

Based on the alpha blending formula, I wrote this helper function to determine an RGBA color in a different way. It has to clip values into [0, 1] to make matplotlib happy, which destroys some "secondary" color information.
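The inversion behind this can be written out per channel. With background b, foreground f, opacity α and target color c (all in [0, 1]), alpha blending and its inverse are:

```latex
c = \alpha f + (1 - \alpha)\, b
\quad\Longrightarrow\quad
f = \frac{c - (1 - \alpha)\, b}{\alpha}
\;\overset{b = 1}{=}\;
1 - \frac{1 - c}{\alpha}
```

On a white background (b = 1), f drops below 0 whenever c < 1 - α, i.e. for any channel darker than 1 - α. Those channels must be clipped to 0, and the blend can then no longer reproduce c exactly; the lower α is, the more channels hit this limit.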

The resulting colormap (bottom) is closer to the non-transparent original (top) than the transparentify method from before (middle):

[image: Figure_2]

With this transparentify2 method, your "new" plotting method is much closer to the original colormap:

[image: Figure_1]

@MaPePeR maybe you can copy the transparentify2 into your script to check its luminance profile?

(I think we may be getting closer to a quantitatively correct method 🎉)

MaPePeR commented 3 years ago

@michaelosthege That's a very good idea!!

This is the result:

[image: transparency_result7]

By changing the alpha value in transparentify2 one can even use it for uniform transparency, but because of the clip it does not look good for low alpha values:

    cm_new = np.array([
        get_fg_with_alpha(c=cmap(n), alpha=0.9)
        for n in range(cmap.N)
    ])

For 0.9 it works and the resulting colormap looks identical to the original on a white background. With alpha=0.8 the center region becomes grayish, though.
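A numeric check of why lower alpha values degrade, using a hypothetical dark colormap entry and a simplified re-implementation of the clipped inverse (mirroring the `get_fg_with_alpha` logic in the script below, with a white background assumed):

```python
import numpy as np


def fg_for_white(c, alpha):
    """Invert alpha blending onto white and clip to [0, 1], like get_fg_with_alpha."""
    return np.clip((np.asarray(c, dtype=float) - (1 - alpha)) / alpha, 0, 1)


def blend_on_white(fg, alpha):
    """Source-over blend of fg with opacity alpha onto a white background."""
    return alpha * fg + (1 - alpha)


dark_red = np.array([0.4, 0.0, 0.05])  # hypothetical dark colormap entry
errors = {}
for alpha in (0.9, 0.8):
    shown = blend_on_white(fg_for_white(dark_red, alpha), alpha)
    errors[alpha] = float(np.max(np.abs(shown - dark_red)))
    print(alpha, shown.round(3), round(errors[alpha], 3))
```

The clipped channels come back lifted towards white by 1 - α (up to 0.1 at α = 0.9 versus 0.2 at α = 0.8 here), which washes out dark colors and is consistent with the center losing saturation at alpha=0.8.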

michaelosthege commented 1 year ago

Just cleaning up my desktop.. This is the latest version of the script used to make some of the figures above:

cmaptests.py

```python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import matplotlib.cm
import matplotlib.colors
import numpy as np

samples = np.linspace([0], [100]).T
CM = matplotlib.cm.Reds


def get_fg_with_alpha(c, alpha):
    """
    Determines a foreground color [fg] with an alpha value [alpha]
    such that overlayed on [bg] it results in the color [c] as close as possible.

    Alpha blending:
        c = alpha * fg + (1 - alpha) * bg
        (c - (1 - alpha) * bg) / alpha = fg

        c = fg_component + bg_component
        fg_component = c - bg_component
        fg = fg_component / alpha
    """
    c = np.array(c, dtype=float)[:3]
    bg = np.ones(3, dtype=float)
    if alpha == 0:
        fg = c
    else:
        fg = (c - (1 - alpha) * bg) / alpha
    fg = np.clip(fg, 0, 1)
    result = np.array(tuple(fg) + (alpha,))

    # check = alpha * fg + (1 - alpha) * bg
    # checks = {
    #     "blend matches expectation": np.allclose(check, c),
    #     "> 0": np.all(0 <= result),
    #     "< 1": np.all(result <= 1),
    # }
    # if not all(checks.values()):
    #     print(f"""
    #     {checks}
    #     c = {c}
    #     alpha = {alpha}
    #     result = {result}
    #     check = {check}
    #     """)
    return result


def transparentify1(cmap: matplotlib.colors.Colormap) -> matplotlib.colors.ListedColormap:
    """Creates a transparent->color version from a standard colormap.

    Stolen from https://stackoverflow.com/a/37334212/4473230

    Testing
    -------
    The following code block can be used to plot a (transparent) colormap in a way that one
    can check if the transparency actually works. This is not trivial because the background
    is often white already.
    Check the thread under https://github.com/matplotlib/matplotlib/pull/17888#issuecomment-845253158
    for updates about automatic cmap representation in notebooks that could make this snippet obsolete.

    x = numpy.arange(256)
    fig, ax = pyplot.subplots(figsize=(12,1))
    ax.scatter(x, numpy.ones_like(x) - 0.01, s=100, c=[
        cm.Reds(v)
        for v in x
    ])
    ax.scatter(x, numpy.ones_like(x) + 0.01, s=100, c=[
        redsT(v)
        for v in x
    ])
    ax.set_ylim(0.9, 1.1)
    pyplot.show()
    """
    # Get the colormap colors
    cm_new = np.array(cmap(np.arange(cmap.N)))
    cm_new[:, 3] = np.linspace(0, 1, cmap.N)
    return matplotlib.colors.ListedColormap(cm_new)


def transparentify2(cmap: matplotlib.colors.Colormap) -> matplotlib.colors.ListedColormap:
    """Creates a transparent->color version from a standard colormap.

    Unlike transparentify1, each entry gets an increasing alpha together with a foreground
    color (from get_fg_with_alpha) that blends onto a white background as close as possible
    to the original color. See transparentify1 for a snippet to test the result.
    """
    # Get the colormap colors
    cm_new = np.array([
        get_fg_with_alpha(c=cmap(n), alpha=n / cmap.N)
        for n in range(cmap.N)
    ])
    return matplotlib.colors.ListedColormap(cm_new)


def old_plot(ax, x, samples, fill_alpha, apply, cmap=CM):
    if apply:
        cmap = apply(cmap)
        ax.text(np.mean(x), 101, f"old\nfill_alpha={fill_alpha}\n{apply.__name__}", ha="center", va="bottom")
    else:
        ax.text(np.mean(x), 101, f"old\nfill_alpha={fill_alpha}\n", ha="center", va="bottom")
    percs = np.linspace(51, 99, 40)
    colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs))
    #samples = samples.T
    #x = x.flatten()
    for i, p in enumerate(percs[::-1]):
        upper = np.percentile(samples, p, axis=1)
        lower = np.percentile(samples, 100 - p, axis=1)
        color_val = colors[i]
        ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=fill_alpha, zorder=0, edgecolor="none")
    ax.plot(x, [0, 100], zorder=-20, color='blue')


def new_plot(ax, x, samples, fill_alpha, apply, cmap=CM):
    if apply:
        cmap = apply(cmap)
        ax.text(np.mean(x), 101, f"new\nfill_alpha={fill_alpha}\n{apply.__name__}", ha="center", va="bottom")
    else:
        ax.text(np.mean(x), 101, f"new\nfill_alpha={fill_alpha}\n", ha="center", va="bottom")
    #N = 79
    #percs = np.linspace(1, 99, N)
    percs = np.concatenate([np.linspace(1, 49, 40), np.linspace(51, 99, 40)])
    N = len(percs)
    #colors = 1 - np.abs(percs - 50) / (np.max(percs) - 50)
    #colors = 2 * (percs - np.min(percs)) / (np.max(percs) - np.min(percs))
    # center of each band between neighbouring percentiles
    colors = (percs[:-1] + percs[1:]) / 2
    colors = 2 * (colors - np.min(colors)) / (np.max(colors) - np.min(colors))
    percs_bounds = np.percentile(samples, percs, axis=1)
    for i, p in enumerate(percs[:(N-1)//2]):
        color_val = colors[i]
        ax.fill_between(x, percs_bounds[i], percs_bounds[i + 1], color=cmap(color_val), alpha=fill_alpha, edgecolor="none", zorder=0)
        ax.fill_between(x, percs_bounds[-i-1], percs_bounds[-i-2], color=cmap(color_val), alpha=fill_alpha, edgecolor="none", zorder=0)
    if N % 2 == 0:
        ax.fill_between(x, percs_bounds[N//2 - 1], percs_bounds[N//2], color=cmap(1.0), alpha=fill_alpha, edgecolor="none", zorder=0)
    ax.plot(x, [0, 100], zorder=-20, color='blue')


def plot_cmap(ax, x, cmap=CM):
    for i in range(256):
        ax.fill_between(x, 100 - 50 * i/255, 50 * i/255, color=cmap(i/255), zorder=10, edgecolor="none")


def run():
    fig, ax = plt.subplots(nrows=1, ncols=1)
    old_plot(ax, [0, 1], samples, fill_alpha=0.8, apply=None)
    old_plot(ax, [1, 2], samples, fill_alpha=None, apply=None)
    old_plot(ax, [2, 3], samples, fill_alpha=None, apply=transparentify1)
    old_plot(ax, [3, 4], samples, fill_alpha=None, apply=transparentify2)
    new_plot(ax, [4, 5], samples, fill_alpha=None, apply=None)
    new_plot(ax, [5, 6], samples, fill_alpha=None, apply=transparentify1)
    new_plot(ax, [6, 7], samples, fill_alpha=None, apply=transparentify2)
    for x in np.arange(0.5, 7.5):
        plot_cmap(ax, [x-0.02, x+0.02])
    ax.set(
        ylim=(0, 100),
        xlim=(0, 7),
        xticks=[],
    )
    plt.show()


def test():
    from pymc3.gp.util import plot_gp_dist
    fig, ax = plt.subplots(nrows=1, ncols=1)
    plot_gp_dist(
        ax=ax,
        x=np.arange(10),
        samples=np.vstack([
            np.random.normal(size=(1000, 10)),
            np.random.normal(loc=2, size=(1000, 10)),
        ]),
        palette=CM,
        plot_samples=False
    )
    plt.show()


def comp():
    x = np.arange(256)
    fig, ax = plt.subplots(figsize=(12, 3))
    cb = np.array([
        np.arange(48) % 2,
        1 - np.arange(48) % 2,
    ])
    ax.imshow(cb * 0.25, cmap="binary", aspect="auto", extent=(-0.5, 255.5, -0.5, -1), vmax=1)
    ax.imshow(cb * 0.25, cmap="binary", aspect="auto", extent=(-0.5, 255.5, -1.5, -2), vmax=1)
    cm = matplotlib.cm.Reds
    cmT1 = transparentify1(cm)
    cmT2 = transparentify2(cm)
    for v in x:
        ax.fill_between(
            [v-0.5, v+0.5], [0, 0], [1, 1],
            color=cm(v), edgecolor="none"
        )
        ax.fill_between(
            [v-0.5, v+0.5], [0, 0], [-1, -1],
            color=cmT1(v), edgecolor="none"
        )
        ax.fill_between(
            [v-0.5, v+0.5], [-1, -1], [-2, -2],
            color=cmT2(v), edgecolor="none"
        )
    ax.set(
        ylim=(-2, 1),
        xlim=(-0.5, 255.5)
    )
    plt.show()


if __name__ == "__main__":
    run()
    #test()
    comp()
```