danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Order of arguments to Distribution.sample #62

Tennessee-Wallaceh closed this issue 1 year ago

Tennessee-Wallaceh commented 1 year ago

A bit of a minor one, but to me the Distribution.sample method feels a bit unnatural, as it forces you either to pass condition before sample_shape, or to pass sample_shape as a kwarg.

How about Distribution.sample(key: jr.PRNGKey, sample_shape: Tuple[int] = (), condition: Optional[Array] = None) instead?

Then the user can simply do Distribution.sample(key, (n_samp,)).

This shouldn't be a breaking change, as any previous code would be using kwargs. What do you think?
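To make the difference concrete, here is a small sketch using two hypothetical stand-in functions (not flowjax code) that only mirror the current and proposed signatures. inspect.signature(...).bind shows which parameter a positional (n_samp,) would land in under each ordering; the integer key is a placeholder for a real JAX key.

```python
import inspect


# Hypothetical stand-ins mirroring the two signatures; they do no sampling.
def sample_current(key, condition=None, sample_shape=()):
    return None


def sample_proposed(key, sample_shape=(), condition=None):
    return None


def bind_positional(fn, *args):
    """Return every parameter of fn with the value it gets from these positional args."""
    bound = inspect.signature(fn).bind(*args)
    bound.apply_defaults()  # fill in parameters that were not passed
    return dict(bound.arguments)


key, n_samp = 0, 5  # placeholders; a real key would come from jax.random
print(bind_positional(sample_current, key, (n_samp,)))   # (5,) binds to condition
print(bind_positional(sample_proposed, key, (n_samp,)))  # (5,) binds to sample_shape
```

Under the current order a positional shape is silently taken as the condition, which is why sample_shape has to be passed by keyword there.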

danielward27 commented 1 year ago

Thanks for the suggestion. I do think you have a point, but it does introduce some breaking changes.

Firstly, dist.sample(key, condition) is completely valid. When condition.shape == dist.cond_shape, this returns a single sample from dist. If condition has additional leading "batch" dimensions, we vectorise over these. If a user wants to sample multiple times for each conditioning variable, they can also pass sample_shape, i.e. dist.sample(key, condition, sample_shape). In the simplest case, where the conditioning variable has no leading axes (condition.shape == dist.cond_shape), passing sample_shape gives repeated samples given that single conditioning variable.

More broadly, the output shape is sample_shape + cond_leading_shape + dist.shape, where cond_leading_shape is any leading dimensions of the conditioning variable beyond dist.cond_shape. Changing the order of the arguments would break these calls (although the fix for users would be very easy).
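The output-shape rule above can be written as plain tuple arithmetic. The helper below is hypothetical (it is not part of flowjax) and assumes shapes are given as Python tuples, that the trailing axes of condition must equal dist.cond_shape, and that condition is None for unconditional sampling.

```python
def expected_sample_shape(dist_shape, sample_shape=(), cond_shape=(), condition_shape=None):
    """Hypothetical helper: sample_shape + cond_leading_shape + dist.shape."""
    if condition_shape is None:
        # Unconditional sampling: no leading "batch" axes come from a condition.
        cond_leading_shape = ()
    else:
        condition_shape, cond_shape = tuple(condition_shape), tuple(cond_shape)
        n_lead = len(condition_shape) - len(cond_shape)
        # The trailing axes of the condition must match dist.cond_shape exactly.
        if n_lead < 0 or condition_shape[n_lead:] != cond_shape:
            raise ValueError(
                f"condition shape {condition_shape} does not end with cond_shape {cond_shape}"
            )
        cond_leading_shape = condition_shape[:n_lead]
    return tuple(sample_shape) + cond_leading_shape + tuple(dist_shape)


# Example: dist.shape == (2,), dist.cond_shape == (3,)
print(expected_sample_shape((2,), cond_shape=(3,), condition_shape=(3,)))  # (2,)
print(expected_sample_shape((2,), cond_shape=(3,), condition_shape=(10, 3)))  # (10, 2)
print(expected_sample_shape((2,), (5,), (3,), (10, 3)))  # (5, 10, 2)
print(expected_sample_shape((2,), sample_shape=(5,)))  # (5, 2), unconditional
```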

I think that, if you aren't worried about introducing a breaking change, the choice of ordering mostly comes down to which of these is better:

  1. Having dist.sample(key, shape) and dist.sample(key, condition=condition) (yours)
  2. Having dist.sample(key, condition) and dist.sample(key, sample_shape=sample_shape) (current)

I'm inclined to agree that the first option is slightly better (especially since you will never pass condition to unconditional distributions). I'll have a think about whether it's worth introducing a breaking change for this.

danielward27 commented 1 year ago

Also, looking at it, the documentation for the sample method could probably be a bit clearer, so I'll update that.

Tennessee-Wallaceh commented 1 year ago

Ah, ok, I hadn't appreciated the complexity of conditional sampling. I suppose it still comes down to whether dist.sample(key, shape) or dist.sample(key, condition) is the more common/natural call. In my current work I'm mostly doing unconditional (VI) stuff, so that's where my preference comes from.

Alternatively, a "stricter" approach would be to define the method as Distribution.sample(self, key, *, condition=None, sample_shape=()). Every argument after * must then be passed by keyword, so users would have to write either dist.sample(key, condition=condition) or dist.sample(key, sample_shape=sample_shape). This could be a nice way to avoid any ambiguity.
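A minimal sketch of the keyword-only idea, using a hypothetical KeywordOnlyDist class (not a flowjax class) that just echoes what it receives. Anything after the bare * can only be passed by keyword, so a positional shape is rejected with a TypeError rather than being bound to the wrong parameter.

```python
class KeywordOnlyDist:
    """Hypothetical stand-in illustrating the keyword-only signature."""

    def sample(self, key, *, condition=None, sample_shape=()):
        # A real distribution would draw samples here; we just report the inputs.
        return {"condition": condition, "sample_shape": sample_shape}


dist = KeywordOnlyDist()
print(dist.sample(0, sample_shape=(5,)))  # accepted: passed by keyword

try:
    dist.sample(0, (5,))  # rejected: positional argument after the bare *
except TypeError as err:
    print("TypeError:", err)
```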

danielward27 commented 1 year ago

Decided to switch the argument order in https://github.com/danielward27/flowjax/pull/68. I don't think forcing keywords should be necessary, since providing the wrong argument presumably leads to an error either way.
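The "error either way" argument can be illustrated with a toy sampler that uses the swapped order. Everything here is hypothetical: toy_sample is not flowjax's implementation, Python's random module stands in for jax.random, condition is accepted but unused, and the explicit sample_shape check is an assumption added so that an old-style call passing a condition positionally fails loudly.

```python
import math
import random


def toy_sample(key, sample_shape=(), condition=None):
    """Hypothetical sampler with the swapped order: returns prod(sample_shape) draws.

    condition is accepted to mirror the signature but unused in this toy.
    """
    # Assumed guard: a misplaced condition (e.g. a list of floats) is not a valid shape.
    if not isinstance(sample_shape, tuple) or not all(
        isinstance(d, int) and d >= 0 for d in sample_shape
    ):
        raise TypeError(f"sample_shape must be a tuple of non-negative ints, got {sample_shape!r}")
    rng = random.Random(key)  # stand-in for a JAX PRNG key
    return [rng.gauss(0.0, 1.0) for _ in range(math.prod(sample_shape))]


print(len(toy_sample(0, (2, 3))))  # 6 standard-normal draws, as a flat list

try:
    toy_sample(0, [0.5, -1.0])  # old-style call: a condition passed positionally
except TypeError as err:
    print("TypeError:", err)
```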