secondmind-labs / trieste

A Bayesian optimization toolbox built on TensorFlow
Apache License 2.0

Start adding some check_shape decorations #770

Closed: uri-granta closed this 1 year ago

uri-granta commented 1 year ago

Related issue(s)/PRs: #130

Summary

Start adding check_shapes decorations, which have several advantages over our existing ad hoc use of tf.debugging.assert_shapes:

  1. better checking (e.g. it can check batch dimensions, input broadcasting, return values, object attributes and more)
  2. more consistent checking (e.g. can inherit checks from abstract methods)
  3. clearer code documentation (potentially including docstring rewriting, but see below)
  4. clearer error messages, e.g.:
E             ShapeMismatchError: 
E             Tensor shape mismatch.
E               Function: expected_improvement.__call__
E                 Declared: /home/uri.granta/code/trieste/trieste/acquisition/function/function.py:215
E                 Argument: x
E                   Note:     This acquisition function only supports batch sizes of one
E                   Expected: [N..., 1, D]
E                   Actual:   [1, 2, 1]

However, note that:

Fully backwards compatible: yes

Performance impact

Looking at integration test runtimes, the additional shape checks added here don't (currently) seem to have a significant performance impact by themselves. However:

Overall, the unit test runtime doesn't seem to be affected, but the (non-slow) integration test runtime increases from around 31 minutes to 38 minutes (roughly 23% longer). If this becomes an issue, we can easily split the integration tests into two parallel jobs.

PR checklist