wookayin / expt

Experiment. Plot. Tabulate.
MIT License

Allow custom representative_fn in hypothesis.plot() #5

Closed: wookayin closed this 2 years ago.

wookayin commented 2 years ago

Fixes #4

Until now, the representative value used to draw a plot/curve was always the mean of the individual runs. This PR adds a new parameter, representative_fn, which allows more general representative values such as the median or mode.
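To make the difference between representative values concrete, here is a minimal pandas sketch, not expt's code. It assumes that a hypothesis's `grouped` attribute behaves like the per-step grouping of all its runs, modeled here as `pd.concat(runs).groupby(level=0)`; the `runs` data is made up for illustration.

```python
import pandas as pd

# Hypothetical stand-in for a hypothesis: three runs of the same metric,
# each a DataFrame indexed by training step (illustrative data only).
runs = [
    pd.DataFrame({"loss": [1.0, 0.8, 0.6]}, index=[0, 1, 2]),
    pd.DataFrame({"loss": [2.0, 1.0, 0.4]}, index=[0, 1, 2]),
    pd.DataFrame({"loss": [9.0, 0.9, 0.5]}, index=[0, 1, 2]),  # outlier at step 0
]

# Assumption: pool all runs and group the rows that share the same step.
grouped = pd.concat(runs).groupby(level=0)

mean_curve = grouped.mean()      # the previous, fixed representative value
median_curve = grouped.median()  # what a representative_fn could return instead
```

At step 0 the outlier run pulls the mean up to 4.0, while the median stays at 2.0. This robustness to outliers is one reason to prefer a non-mean representative value.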

As an accompanying change, err_fn may now return either a single DataFrame representing the radius of the error band around the mean (as before), or a tuple of two DataFrames that gives the band's (lower, upper) range explicitly, for cases where a custom representative value other than the mean is used.
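The two accepted return forms of err_fn can be normalized into explicit bounds roughly as follows. This is a minimal sketch; `resolve_band` is a hypothetical helper written for illustration, not a function from expt.

```python
from typing import Tuple, Union

import pandas as pd

# Either form (i), a radius, or form (ii), an explicit (lower, upper) pair.
ErrResult = Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]


def resolve_band(mean: pd.DataFrame, err: ErrResult) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Convert an err_fn result into explicit (lower, upper) bounds (hypothetical helper)."""
    if isinstance(err, tuple):
        # Form (ii): the range is given explicitly, so use it as-is.
        lower, upper = err
        return lower, upper
    # Form (i): a radius, so the band is centered at the mean.
    return mean - err, mean + err


# Usage: a radius (e.g. a standard deviation) around a two-step mean curve.
mean = pd.DataFrame({"loss": [1.0, 0.5]}, index=[0, 1])
std = pd.DataFrame({"loss": [0.25, 0.125]}, index=[0, 1])
lower, upper = resolve_band(mean, std)  # lower: 0.75, 0.375; upper: 1.25, 0.625
```

Centering form (i) at the mean keeps the old behavior intact. Any band that is not symmetric around the mean has to be expressed with form (ii).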

wookayin commented 2 years ago

@vwxyzjn This is a preview of representative_fn for Hypothesis.plot (#4). I still need to add more test cases, but the API should be usable now. Feedback would be appreciated!

vwxyzjn commented 2 years ago

Thank you! I tried running

from typing import cast

import pandas as pd
from expt import Hypothesis

def repr_fn(h: Hypothesis) -> pd.DataFrame:
    # A dummy function that manipulates the representative value ('median')
    df = cast(pd.DataFrame, h.grouped.median())
    # df['loss'] = np.asarray(df.reset_index()['step']) * -1.0
    return df

# `h` (a Hypothesis) and `ax` (a matplotlib Axes) are defined elsewhere
g = h.plot(x='global_step', y="charts/avg_episodic_return", rolling=50,
           n_samples=400, legend=False, err_fn=None,
           err_style="fill",
           suptitle="", ax=ax,
           representative_fn=repr_fn)

and it worked out of the box. Quick question, though: what would err_fn look like in the case of the median? Also, what if I don't want to show the error bar and only want to show the median?

wookayin commented 2 years ago

As per the docstring, you can simply use representative_fn=lambda h: h.grouped.median().

When a custom representative value is used, a single DataFrame returned by err_fn is, by default, still treated as the radius of an error band centered at the mean, not at the custom representative value. The doc of err_fn says:

A strategy to compute the error range when err_style is band or fill. Defaults to the standard deviation, i.e., hypothesis.grouped.std(). This function may return either:

  • (i) a single DataFrame representing the radius of the error band (e.g., the standard deviation), which must have the same columns and index as the hypothesis; or
  • (ii) a tuple of two DataFrames representing the error range (lower, upper). Both DataFrames must also have the same columns and index as the hypothesis.

In the case of (i), we assume that a custom representative_fn is NOT being used, i.e., that the representative value of the hypothesis is its grouped mean, hypothesis.mean().
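For a median representative value, one natural choice under form (ii) is an interquartile band: the 25th and 75th percentiles at each step, which always bracket the median. The sketch below is hedged. `ToyHypothesis` is a hypothetical stand-in that exposes only a `grouped` attribute, which it assumes behaves like a per-step groupby over all runs. `median_repr_fn` and `iqr_err_fn` are illustrative names, not part of expt.

```python
from typing import List, Tuple

import pandas as pd


class ToyHypothesis:
    """Hypothetical stand-in exposing only `.grouped` (not expt's Hypothesis)."""

    def __init__(self, runs: List[pd.DataFrame]):
        self.runs = runs

    @property
    def grouped(self):
        # Assumption: rows from all runs that share an index (step) form a group.
        return pd.concat(self.runs).groupby(level=0)


def median_repr_fn(h) -> pd.DataFrame:
    # Representative value: per-step median across runs.
    return h.grouped.median()


def iqr_err_fn(h) -> Tuple[pd.DataFrame, pd.DataFrame]:
    # Form (ii): explicit (lower, upper) = 25th and 75th percentiles per step.
    return h.grouped.quantile(0.25), h.grouped.quantile(0.75)


# Usage: five runs, two steps each (illustrative data only).
runs = [pd.DataFrame({"return": [float(v), float(10 * v)]}, index=[0, 1])
        for v in range(1, 6)]
h = ToyHypothesis(runs)
center = median_repr_fn(h)  # 3.0 and 30.0
lower, upper = iqr_err_fn(h)  # (2.0, 20.0) and (4.0, 40.0)
```

With expt itself, functions of this shape would presumably be passed as representative_fn=median_repr_fn, err_fn=iqr_err_fn together with err_style="fill". That usage is an assumption based on the docstring quoted above and has not been checked against the library.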

Note that the error band (with filled regions) is shown only when err_style='fill' or err_style='band'. To display only the median without an error band, use err_style=None.