secondmind-labs / trieste

A Bayesian optimization toolbox built on TensorFlow
Apache License 2.0

Fix Batch Reparam Sampler support for varying batch sizes #748

Closed uri-granta closed 1 year ago

uri-granta commented 1 year ago

The recent commit to make reparameterization samplers more XLA friendly (#718) broke one of the notebooks. It turns out that Async BO calls BatchReparametrizationSampler with different sized batches (due to pending points), meaning ._eps needs to remain dynamically shaped in this case. Interestingly, the JIT compilation test still passes.
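For illustration, here is a minimal sketch of what "dynamically shaped" means for a cached noise variable like ._eps. It is not the trieste implementation: the names `eps`, `resample` and `sample_size` are made up for the example, and the only assumption is that the sample size is known up front while the other two dimensions are left unspecified.

```python
import tensorflow as tf

sample_size = 5  # S: known when the sampler is created (illustrative value)

# Declare the first two dimensions as None so that later assignments
# are allowed to change them.
eps = tf.Variable(
    tf.zeros([0, 0, sample_size], dtype=tf.float64),
    shape=[None, None, sample_size],
    trainable=False,
)


def resample(num_latent: int, batch_size: int) -> None:
    """Overwrite the cached noise with a fresh draw of the requested shape."""
    eps.assign(tf.random.normal([num_latent, batch_size, sample_size], dtype=tf.float64))


resample(1, 3)  # e.g. a batch of three points
first_shape = tf.shape(eps).numpy().tolist()

resample(1, 2)  # a smaller batch, e.g. because one point is still pending
second_shape = tf.shape(eps).numpy().tolist()
```

Had the variable been created with a fully specified static shape, the second assignment would be rejected, which is the kind of failure the notebook ran into.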

uri-granta commented 1 year ago

It might be a better idea to allow the shape to be specified / or have a flag like "relax_shapes" which does this.

Can you explain what you mean? Having the user specify the shape when creating the BatchReparametrizationSampler (or when calling sample for the first time) feels like a major usability issue? And I thought relax_shapes/reduce_retracing is about the arguments of compiled functions, not variable shapes?
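To make the distinction concrete, here is a small sketch of what reduce_retracing affects: how often a tf.function is retraced when its input arguments change shape. The function names and the trace-recording lists are made up for the example, and it assumes a TensorFlow release recent enough to have the reduce_retracing argument (older releases called it experimental_relax_shapes).

```python
import tensorflow as tf

strict_traces = []
relaxed_traces = []


@tf.function
def strict_sum(x):
    # Python side effect: runs only while the function is being traced.
    strict_traces.append(x.shape)
    return tf.reduce_sum(x)


@tf.function(reduce_retracing=True)
def relaxed_sum(x):
    relaxed_traces.append(x.shape)
    return tf.reduce_sum(x)


for n in (2, 3, 4):
    strict_sum(tf.ones([n]))
    relaxed_sum(tf.ones([n]))

# strict_sum is traced once per distinct input shape; relaxed_sum may
# generalise the argument shape and reuse an earlier trace instead.
print(len(strict_traces), len(relaxed_traces))
```

Note that nothing here touches the shape of a tf.Variable: the option only changes how input signatures are generalised.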

Dynamic shape can have performance implications.

True, though worth remembering that it was already dynamically shaped on the previous release, so it's not a degradation.

sam-willis commented 1 year ago

Can you explain what you mean?

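A few options:

* pass in the shape of L, S and B

* pass a flag which enables fixed_shape/relaxed_shapes, i.e. either do new or old behaviour behind a flag, so if shape is fixed it expects the same shape to always be used (but will infer it based on first use)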

Having the user specify the shape when creating the BatchReparametrizationSampler (or when calling sample for the first time) feels like a major usability issue?

I don't think either of these is really a usability problem, because the default can always be no change. We already pass S, but you could default L and B to None and construct the new TensorShape from that. Equally, you could default the flag to whichever behaviour you wanted.

And I thought relax_shapes/reduce_retracing is about the arguments of compiled functions, not variable shapes?

Not sure what this means.

uri-granta commented 1 year ago

Can you explain what you mean?

A few options:

* pass in the shape of L, S and B

* pass a flag which enables fixed_shape/relaxed_shapes, i.e. either do new or old behaviour behind a flag, so if shape is fixed it expects the same shape to always be used (but will infer it based on first use)

Ah, ok. I think adding a boolean option to disable the relaxed shapes does make sense. I'll do that in a subsequent PR, as this one is blocking the docs build from running successfully.
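As a rough idea of what such a boolean option could look like, here is an eager-mode sketch. Everything in it is hypothetical: `NoiseCache`, `fixed_shape` and `get` are illustrative names, not trieste's API, and the follow-up PR may well look different. With the flag off the variable keeps L and B unspecified; with it on, the shape is inferred from the first use and any later mismatch is rejected.

```python
import tensorflow as tf


class NoiseCache:
    """Hypothetical helper for illustration only; not part of trieste."""

    def __init__(self, sample_size: int, fixed_shape: bool = False):
        self._sample_size = sample_size
        self._fixed_shape = fixed_shape
        self._eps = None  # created on first use

    def get(self, num_latent: int, batch_size: int) -> tf.Variable:
        shape = [num_latent, batch_size, self._sample_size]
        if self._eps is None:
            # Fixed mode infers the full static shape from the first call;
            # relaxed mode leaves L and B unspecified.
            declared = shape if self._fixed_shape else [None, None, self._sample_size]
            self._eps = tf.Variable(
                tf.random.normal(shape, dtype=tf.float64),
                shape=declared,
                trainable=False,
            )
        elif tf.shape(self._eps).numpy().tolist() != shape:
            if self._fixed_shape:
                raise ValueError(f"expected shape {self._eps.shape}, got {shape}")
            # Relaxed mode: redraw the noise at the new shape.
            self._eps.assign(tf.random.normal(shape, dtype=tf.float64))
        return self._eps


relaxed = NoiseCache(sample_size=5)
relaxed.get(1, 3)
relaxed_shape = tf.shape(relaxed.get(1, 2)).numpy().tolist()  # batch size changed

fixed = NoiseCache(sample_size=5, fixed_shape=True)
fixed.get(1, 3)
try:
    fixed.get(1, 2)
    fixed_rejected = False
except ValueError:
    fixed_rejected = True
```

Defaulting the flag to the relaxed behaviour would keep Async BO working unchanged, while callers that always use the same batch size could opt in to the static shape.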