Closed Tennessee-Wallaceh closed 1 year ago
Thanks for the suggestion. I do think you have a point, but it does introduce some breaking changes.
Firstly, `dist.sample(key, condition)` is completely valid. In cases where `condition.shape == dist.cond_shape`, this returns a single sample from `dist`. If `condition` has additional leading "batch" dimensions, then we vectorise over these. If a user wants to sample multiple times for each conditioning variable, they can also pass `sample_shape`, i.e. `dist.sample(key, condition, sample_shape)`. In the simplest case, if the passed conditioning variable has no leading axes (such that `condition.shape == dist.cond_shape`), then passing `sample_shape` allows repeated sampling given that conditioning variable. More broadly, the output shape is `sample_shape + cond_leading_shape + dist.shape`, where `cond_leading_shape` is any leading dimensions in the conditioning variable beyond `dist.cond_shape`. Changing the order of the arguments would break these calls (although it would be a very easy fix for users).
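To make the shape rule above concrete, here is a minimal pure-Python sketch. The helper `expected_sample_shape` is hypothetical and not part of the library; it only does the tuple arithmetic described in the comment, assuming shapes are given as tuples of ints.

```python
def expected_sample_shape(dist_shape, cond_shape, condition_shape=None, sample_shape=()):
    """Return sample_shape + cond_leading_shape + dist_shape.

    Hypothetical helper for illustration only; it mirrors the rule stated
    in the comment and does not call any library code.
    """
    if condition_shape is None:
        cond_leading_shape = ()
    else:
        condition_shape = tuple(condition_shape)
        n_lead = len(condition_shape) - len(cond_shape)
        # The trailing axes of the condition must match cond_shape exactly.
        if n_lead < 0 or condition_shape[n_lead:] != tuple(cond_shape):
            raise ValueError("condition shape must end with cond_shape")
        cond_leading_shape = condition_shape[:n_lead]
    return tuple(sample_shape) + cond_leading_shape + tuple(dist_shape)


# A distribution over 2 variables, conditioned on a 3-vector.
single = expected_sample_shape((2,), (3,), condition_shape=(3,))
batched = expected_sample_shape((2,), (3,), condition_shape=(5, 3))
repeated = expected_sample_shape((2,), (3,), condition_shape=(5, 3), sample_shape=(10,))
```

Under these assumptions, an unbatched condition gives shape `(2,)`, a batch of 5 conditions gives `(5, 2)`, and adding `sample_shape=(10,)` gives `(10, 5, 2)`.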
I think that, if you aren't worried about introducing a breaking change, the decision on the ordering mostly comes down to which is better out of
- `dist.sample(key, shape)` and `dist.sample(key, condition=condition)` (yours)
- `dist.sample(key, condition)` and `dist.sample(key, sample_shape=sample_shape)` (current)

I'm inclined to agree that the top one is slightly better (especially since you will never pass `condition` to unconditional distributions). I'll have a think about whether it's worth introducing a breaking change for this.
Also, looking at it, the documentation for the sample method could probably be a bit clearer, so I'll update that.
Ah ok, I hadn't appreciated the complexity with the conditional sampling.
I still suppose it comes down to whether `dist.sample(key, shape)` or `dist.sample(key, condition)` is more common/natural.
In my current work I'm mostly doing unconditional (VI) stuff, so that's where my preference comes from.
Alternatively, another "stricter" approach would be to define the method as `Distribution.sample(self, key, *, condition=None, sample_shape=())`.
This would force the user to use kwargs after `*`, so forces the use of either `dist.sample(key, condition=condition)` or `dist.sample(key, sample_shape=sample_shape)`.
This could be a nice way to avoid any ambiguity.
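As a rough illustration of the keyword-only idea, the sketch below uses a stand-in class (`ToyDistribution` is hypothetical, and the integer `key` is just a placeholder for a PRNG key). It only shows how Python treats arguments after `*`, not how the library implements sampling.

```python
class ToyDistribution:
    """Stand-in class used only to show the keyword-only signature."""

    def sample(self, key, *, condition=None, sample_shape=()):
        # Return the bound arguments so the effect of each call is visible.
        return {"key": key, "condition": condition, "sample_shape": sample_shape}


dist = ToyDistribution()
key = 0  # placeholder; a real call would pass a PRNG key

# Both keyword forms are accepted.
by_shape = dist.sample(key, sample_shape=(10,))
by_condition = dist.sample(key, condition=[1.0, 2.0])

# A second positional argument is rejected by Python itself.
try:
    dist.sample(key, (10,))
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

With this signature the ambiguous positional call fails immediately with a `TypeError`, at the cost of slightly more verbose call sites.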
Decided to switch the argument order in https://github.com/danielward27/flowjax/pull/68. I don't think forcing keyword arguments should be necessary, since providing the wrong argument presumably leads to an error either way.
A bit of a minor one, but to me the `Distribution.sample` method feels a bit unnatural, as it forces either `condition` to be passed or `sample_shape` to be passed as a kwarg.
How about `Distribution.sample(key: jr.PRNGKey, sample_shape: Tuple[int] = (), condition: Optional[Array] = None)` instead?
Then the user can simply do `Distribution.sample(key, (n_samp,))`.
This shouldn't be a breaking change, as any previous code would be using kwargs. What do you think?
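The compatibility question can be checked with a small sketch using only the standard library. The two functions below are hypothetical stand-ins for the old and proposed argument orders (the old order and its defaults are assumed from the discussion in this thread), and `bound` is an illustrative helper built on `inspect.signature`.

```python
import inspect


def sample_old(key, condition=None, sample_shape=()):
    """Stand-in for the earlier order: condition before sample_shape."""


def sample_new(key, sample_shape=(), condition=None):
    """Stand-in for the proposed order: sample_shape before condition."""


def bound(func, *args, **kwargs):
    # Report which parameter each argument is bound to, without calling func.
    binding = inspect.signature(func).bind(*args, **kwargs)
    binding.apply_defaults()
    return dict(binding.arguments)


key = "key"              # placeholder for a PRNG key
cond = "condition-array"  # placeholder for a conditioning array

# Keyword calls bind identically under both orders.
kw_old = bound(sample_old, key, sample_shape=(10,), condition=cond)
kw_new = bound(sample_new, key, sample_shape=(10,), condition=cond)

# A positional second argument changes meaning between the two orders.
pos_old = bound(sample_old, key, cond)
pos_new = bound(sample_new, key, cond)
```

In this sketch, calls that use keywords are unaffected by the reordering, while a positional `sample(key, condition)` call would have its second argument rebound to `sample_shape`.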