Functions like `SystemSurrogate.predict()` and any `__call__(x)`-like function should use the concept of ufuncs in `numpy`. They should be written with only the single-sample use case in mind, where `x.shape == (xdim,)`. They can then be wrapped as a generalized ufunc (for example with `np.vectorize(..., signature="(n)->(m)")`) and automatically broadcast over any input of shape `(..., xdim)` under the hood.
This would significantly simplify the shape-handling logic and indexing in many of these functions.
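A minimal sketch of the idea, assuming a toy linear-plus-`tanh` model: `predict_single` and `W` are hypothetical stand-ins, not the real `SystemSurrogate.predict()`. The single-sample function never indexes a batch axis; `np.vectorize` with a core signature handles every leading dimension.

```python
import numpy as np

# Hypothetical weights for a toy model: ydim=2 outputs, xdim=3 inputs.
W = np.array([[1.0, 2.0, 0.0],
              [0.0, 1.0, -1.0]])

def predict_single(x):
    """Hypothetical single-sample predict: x.shape == (xdim,) -> (ydim,)."""
    return np.tanh(W @ x)

# Generalized-ufunc wrapper: the core signature "(n)->(m)" consumes one
# length-n vector and returns one length-m vector; all leading dims broadcast.
predict = np.vectorize(predict_single, signature="(n)->(m)")

x_batch = np.random.default_rng(0).normal(size=(4, 5, 3))  # shape (..., xdim)
y_batch = predict(x_batch)
print(y_batch.shape)  # (4, 5, 2)
```

Note that `np.vectorize` is a convenience loop, not a compiled kernel, so it buys simpler code rather than speed; a vectorized implementation can still be swapped in later without changing the single-sample contract.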
Writing these methods in the single-sample style would also make them more amenable to `jax.grad` auto-differentiation. Try to rewrite them to be as "pure" (function-like) as possible, with no side effects such as in-place array mutation.
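A sketch of what a "pure" single-sample method buys under JAX, assuming `jax` is installed; `predict_single`, `fix_last_input`, and `loss_single` are hypothetical illustrations, not existing project code. Instead of mutating `x[-1] = 0.0` in place (which JAX arrays do not allow), the functional `.at[...].set(...)` update returns a new array, so `jax.grad` can differentiate straight through it. `jnp.vectorize` then broadcasts the per-sample gradient over `(..., xdim)` exactly as `np.vectorize` does for NumPy.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy weights (ydim=2, xdim=3).
W = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 1.0, -1.0]])

def fix_last_input(x):
    # Hypothetical preprocessing step written functionally:
    # returns a copy with x[-1] set to 0.0 instead of mutating x.
    return x.at[-1].set(0.0)

def predict_single(x):
    # Pure: output depends only on x (and the constant W), no side effects.
    return jnp.tanh(W @ fix_last_input(x))

def loss_single(x):
    # Scalar output so that jax.grad applies directly.
    return jnp.sum(predict_single(x) ** 2)

grad_single = jax.grad(loss_single)  # gradient for one sample, shape (xdim,)

# Broadcast the single-sample gradient over any leading batch dims.
grad_batch = jnp.vectorize(grad_single, signature="(n)->(n)")

x = jnp.array([0.5, -0.2, 0.1])
xs = jnp.ones((4, 5, 3))
print(grad_single(x).shape)  # (3,)
print(grad_batch(xs).shape)  # (4, 5, 3)
```

Because `x[-1]` is overwritten rather than used, its gradient component is exactly zero, which is a quick sanity check that the functional update is differentiated correctly.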