Start adding `check_shape` decorations. This has many advantages over our existing ad hoc use of `tf.debugging.assert_shapes`:
- better checking (e.g. can check batch dimensions, input broadcasting, return types, object attributes and more)
- more consistent checking (e.g. can inherit checks from abstract methods)
- clearer code documentation (potentially including docstring rewriting, but see below)
- clearer error messages, e.g.:
```
E       ShapeMismatchError:
E       Tensor shape mismatch.
E         Function: expected_improvement.__call__
E         Declared: /home/uri.granta/code/trieste/trieste/acquisition/function/function.py:215
E         Argument: x
E         Note: This acquisition function only supports batch sizes of one
E         Expected: [N..., 1, D]
E         Actual: [1, 2, 1]
```
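For intuition, here is a minimal pure-Python sketch of the idea behind such a decorator: declare the expected shape once, next to the function, and get a clear error on any mismatch. This is *not* the real check_shapes API — the `check_shape` spec format and the `Tensor` stand-in below are made up for illustration:

```python
import functools


class ShapeMismatchError(ValueError):
    """Raised when an argument's shape doesn't match its declaration."""


def check_shape(*spec):
    """Hypothetical, simplified decorator: checks that the first positional
    argument's shape matches spec, where None means "any size"."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(x, *args, **kwargs):
            shape = tuple(getattr(x, "shape", ()))
            ok = len(shape) == len(spec) and all(
                want is None or want == got for want, got in zip(spec, shape)
            )
            if not ok:
                raise ShapeMismatchError(
                    f"{fn.__qualname__}: expected {spec}, actual {shape}"
                )
            return fn(x, *args, **kwargs)
        return wrapper
    return decorator


class Tensor:  # trivial stand-in for a real tensor type
    def __init__(self, shape):
        self.shape = shape


@check_shape(None, 1, 2)  # i.e. [N, 1, D] with N unconstrained, D fixed at 2
def expected_improvement(x):
    return "ok"


expected_improvement(Tensor((7, 1, 2)))   # passes
# expected_improvement(Tensor((1, 2, 1)))  # would raise ShapeMismatchError
```

The real library additionally handles variadic batch dimensions (`N...`), return values, and attribute checks; the point here is just that declarations live with the code and failures name the function, argument, and expected vs actual shapes.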
However, note that:
- sphinx-autoapi (which we use for our docs, probably because it's simple and quick) doesn't support docstring rewriting, which means that the argument shapes won't automatically appear in the online docs (though they will appear in interactive docstrings)
- by default, shape checking is disabled for anything wrapped in a `tf.function` (though we explicitly enable it in the unit tests), so the performance impact should be limited; see below for further discussion (including the impact on test runtime)
Fully backwards compatible: yes
Performance impact
Looking at integration test runtimes, the additional shape checks added here don't (currently) seem to have a significant perf impact by themselves. However:
- calling `set_enable_check_shapes(ShapeCheckingState.ENABLED)` for all the tests does slow things down somewhat, though this is primarily due to the additional gpflow shape checks which are now enabled, rather than the new checks added in this PR. For example, for `test_bayesian_optimizer_with_gpr_finds_minima_of_scaled_branin`:
  - `EfficientGlobalOptimization` goes from 20s to 40s
  - `AugmentedExpectedImprovement` goes from 30s to 50s (even though we didn't add any more checks for that function)
  - `Fantasizer` goes from 60s to 80s
- conversely, explicitly disabling shape checking sometimes speeds things up over before, as it removes all the gpflow checks. This doesn't affect `EfficientGlobalOptimization` or `AugmentedExpectedImprovement`, but it speeds `Fantasizer` up from 60s to 30s! We should definitely document and advertise this (though it's more likely to affect research than production, as the latter is more likely to be compiling everything).
Overall, the unit test runtime doesn't seem to be affected, but the (non-slow) integration test runtime increases from around 31 minutes to 38 minutes. If this becomes an issue, we can easily split the integration tests into two parallel jobs.
PR checklist
- [x] The quality checks are all passing
- [x] The bug case / new feature is covered by tests
- [ ] Any new features are well-documented (in docstrings or notebooks)
Related issue(s)/PRs: #130