nstarman opened 3 weeks ago
I think this'd be really cool!
So as you've highlighted, Quax is the way to go about implementing something like this. The way this would work is to provide overloads for every JAX primitive that `diffeqsolve` uses, and then call `quax.quaxify(diffrax.diffeqsolve)(y0=Interval(...), ...)`.
In particular, the point of a library like Quax is that this shouldn't require changing Diffrax at all.
Providing all of those overloads is probably fairly ambitious (and might stress-test just how load-bearing Quax really is 😅), but if you do something like that I'd love to see it!
As a simple example:
```python
from collections.abc import Sequence

import equinox as eqx
import jax
import jax.numpy as jnp
import quax
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from jax.typing import ArrayLike


class ExactMeasurement(quax.ArrayValue):
    array: ArrayLike = eqx.field(converter=jnp.asarray)

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        msg = "Refusing to materialise."
        raise ValueError(msg)

    @staticmethod
    def default(
        primitive: jax.core.Primitive,
        values: Sequence[ArrayLike | quax.Value],
        params: dict,
    ):
        # Unwrap every input, run the primitive on the raw arrays,
        # then wrap every output back up as an ExactMeasurement.
        raw_values: list[ArrayLike] = []
        for value in values:
            if eqx.is_array_like(value):
                raw_values.append(value)
            elif isinstance(value, ExactMeasurement):
                raw_values.append(value.array)
            elif isinstance(value, quax.Value):
                raise NotImplementedError
            else:
                raise AssertionError
        out = primitive.bind(*raw_values, **params)
        return (
            [ExactMeasurement(x) for x in out]
            if primitive.multiple_results
            else ExactMeasurement(out)
        )


vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0.0, 1.0, 2.0, 3.0])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = quax.quaxify(diffeqsolve)(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=ExactMeasurement(1.0),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
print(sol.ts)  # ExactMeasurement(array=f32[4])
print(sol.ys)  # ExactMeasurement(array=f32[4])
```
Obviously, only having the `default` rule is too permissive, since `sol.ts` is now an `ExactMeasurement`, but this shows it's possible.
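Going beyond `default` means writing one rule per primitive. The following hypothetical sketch shows what those rules would compute for an interval type. It writes add, negate and multiply as plain JAX functions on `(lo, hi)` pairs, so it runs without Quax. The names `interval_add`, `interval_neg`, `interval_mul` and `euler_step_interval` are made up for illustration. In a Quax version, each rule would instead be registered on an interval `quax.ArrayValue` with `quax.register(jax.lax.add_p)` and similar calls, which is not shown here.

```python
import jax.numpy as jnp

# Hypothetical interval rules: an interval is a (lo, hi) pair with lo <= hi.


def interval_add(x, y):
    # [a, b] + [c, d] = [a + c, b + d]
    (a, b), (c, d) = x, y
    return a + c, b + d


def interval_neg(x):
    # -[a, b] = [-b, -a]: negation swaps the endpoints
    a, b = x
    return -b, -a


def interval_mul(x, y):
    # [a, b] * [c, d]: the extremes lie among the four endpoint products
    (a, b), (c, d) = x, y
    products = jnp.stack([a * c, a * d, b * c, b * d])
    return products.min(axis=0), products.max(axis=0)


def euler_step_interval(y, h):
    # One explicit Euler step of dy/dt = -y, i.e. y + h * (-y),
    # built only from the per-primitive rules above.
    return interval_add(y, interval_mul((h, h), interval_neg(y)))


y_lo, y_hi = euler_step_interval((jnp.array(1.0), jnp.array(2.0)), 0.1)
# approximately (0.8, 1.9), wider than the exact image [0.9, 1.8]
```

The Euler step exposes the classic dependency problem. Interval arithmetic treats the two occurrences of `y` in `y + h * (-y)` as independent, so the enclosure is sound but looser than the true range. Composing per-primitive rules is therefore a workable but conservative starting point.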
It would be amazing to be able to specify uncertainties on initial conditions (ICs) and have them propagate through to the solutions. I know this is a very difficult problem in general. Off the top of my head, some challenges are:
Thankfully, I think at least point 2 has a workable solution. @patrick-kidger, you've written `quax` to allow for array-ish objects in JAX. My suggestion would be to make `diffeqsolve(y0=)` accept `quax` classes that handle the distribution and its propagation.

Point 1 still remains hard, but there's a useful starting point. The simplest "uncertainty" to support isn't even Gaussian but a simple interval with a lower and an upper bound. That would be a good proof of concept and still useful in its own right! I know that the same result could be accomplished by calling `diffeqsolve` twice (once per bound), but a) the unified API would be a convenience and b) we could hopefully subsequently implement Gaussian and more complex distributions.

To use the opening example from https://docs.kidger.site/diffrax/usage/getting-started/
As a related note, having `quax` classes would also enable a nice bundling of arrays of `y0` into a `MonteCarloMeasurement` approximation of an uncertainty distribution:
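The following hypothetical sketch shows both the interval and the Monte Carlo ideas without Quax. `jax.vmap` runs one solve per initial condition. An interval becomes a batch of its two endpoints, and a `MonteCarloMeasurement`-style ensemble becomes a batch of samples. To keep the example self-contained, a hand-written fixed-step RK4 (`rk4_solve`, a made-up helper) stands in for `diffrax.diffeqsolve`. The two-endpoint trick is only exact here because the solution of this scalar ODE, `y0 * exp(-t)`, is increasing in `y0`.

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # dy/dt = -y, the same vector field as the getting-started example
    return -y


def rk4_solve(y0, ts, n_substeps=100):
    """Hypothetical stand-in for diffeqsolve: classical RK4 with a fixed step.

    Returns the solution at every time in `ts`, with ys[0] == y0.
    """
    y0 = jnp.asarray(y0, dtype=ts.dtype)

    def advance(y, t_pair):
        # Integrate from t_start to t_end using n_substeps equal RK4 steps.
        t_start, t_end = t_pair
        h = (t_end - t_start) / n_substeps

        def step(y, i):
            t = t_start + i * h
            k1 = vector_field(t, y, None)
            k2 = vector_field(t + h / 2, y + h / 2 * k1, None)
            k3 = vector_field(t + h / 2, y + h / 2 * k2, None)
            k4 = vector_field(t + h, y + h * k3, None)
            return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

        y, _ = jax.lax.scan(step, y, jnp.arange(n_substeps))
        return y, y

    _, ys = jax.lax.scan(advance, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])


# One solve per initial condition, batched over the leading axis of y0.
solve_batch = jax.vmap(rk4_solve, in_axes=(0, None))
ts = jnp.array([0.0, 1.0, 2.0, 3.0])

# Interval [1, 2]: solve at both endpoints; row 0 is the lower bound.
ys_bounds = solve_batch(jnp.array([1.0, 2.0]), ts)  # shape (2, 4)

# Monte Carlo ensemble: y0 ~ Normal(1.5, 0.1**2), summarised by mean and std.
key = jax.random.PRNGKey(0)
y0_samples = 1.5 + 0.1 * jax.random.normal(key, (1000,))
ys_samples = solve_batch(y0_samples, ts)  # shape (1000, 4)
ys_mean = ys_samples.mean(axis=0)
ys_std = ys_samples.std(axis=0)
```

Since Diffrax solves can be vmapped, the same batching should carry over to `diffeqsolve` directly. A Quax `MonteCarloMeasurement` would presumably hide that batch axis behind a single array-like `y0`.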