patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Additive SDE throws error with SRK style solvers #474

Open ParticularlyPythonicBS opened 3 months ago

ParticularlyPythonicBS commented 3 months ago

Hi, can you help me debug why this SDE throws an error with SRK solvers but integrates fine with ERK and Milstein? Here is a simplified version of the code:

```python
import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"

import jax
import diffrax as dfx
import jax.numpy as jnp
import time

SEED = 0
KEY = jax.random.PRNGKey(SEED)

m = 0.01  # inertia
gamma = 0.1  # viscosity

amplitude = 0.42  # amplitude of the driving force
omega = 1  # frequency of the driving force
drive_period = 2 * jnp.pi / omega

alpha = -1  # linear spring constant
beta = 1  # cubic spring constant

sigma = 0.123  # noise intensity

x0 = 1.0  # initial position
v0 = 0.0  # initial velocity
state0 = jnp.array([x0, v0])

t_min = 0.0
t_max = 2**10 * drive_period
dt = 2**-8 * drive_period

def functional_duffing(t: float, state: jax.Array,
                       args: list[float]) -> jax.Array:
    x, v = state
    dx = v

    gamma, alpha, beta, amplitude, omega, m = args

    driving = amplitude * jnp.cos(omega * t)
    damping = gamma * v
    spring = alpha * x + beta * x ** 3
    dv = (driving - damping - spring) / m

    dstate = jnp.array([dx, dv])
    return dstate

KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]

brownian_noise = dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)
def noise(t, y, args):
    return jnp.array([0, sigma])

noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts=jnp.arange(t_min, t_max, dt))

begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps=2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")
```

This throws the following error:

ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]

But I am already using the MultiTerm(ODETerm, ControlTerm) format, unless I am misunderstanding something.

Also, the same simulation runs much faster in Mathematica (KloedenPlatenSchurz method). Any suggestions on how to speed this up would be very helpful.

Thanks for this great library!

ParticularlyPythonicBS commented 3 months ago

The error was fixed by passing levy_area=dfx.SpaceTimeLevyArea to the VirtualBrownianTree (the Brownian motion used by the noise term). I'm leaving this open in the hope of suggestions for improving performance, and possibly for improving the error message.

patrick-kidger commented 3 months ago

Definitely agreed that the error message could be improved. I'd be happy to take a PR on that!

As for performance, you appear to be including the compile time as well.

ParticularlyPythonicBS commented 3 months ago

I would love to submit a PR for this! This is the error that is currently thrown: https://github.com/patrick-kidger/diffrax/blob/a37a2767b32990345f8120fd4534e068b8acb919/diffrax/_integrate.py#L1025-L1031

Should this be caught as a different error, or is it better to augment this error message with a suggestion to check the Levy area?

patrick-kidger commented 3 months ago

I think we should augment this error message. Tagging @lockwo, as I think he may have some ideas on this one.

We should probably write out something quite verbose: in particular, what structure we actually got! And if it's the vector field or control type that goes wrong, we should call that out explicitly. (We probably don't need to mention Levy area anywhere; that will come out naturally from a message of the form f"expected control type {foo} but got control type {bar}".)
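To make the suggestion concrete, here is a hypothetical sketch of that kind of per-term message. The helper describe_term_mismatch and the stand-in classes are illustrative assumptions, not diffrax internals. The ODE term's control (the time increment) is modeled here simply as float.

```python
class BrownianIncrement:  # stand-in control type, not diffrax's class
    pass


class SpaceTimeLevyArea(BrownianIncrement):  # stand-in control type
    pass


def describe_term_mismatch(expected_controls, got_controls):
    """Hypothetical helper: describe every term whose control type does not match."""
    lines = []
    for i, (expected, got) in enumerate(zip(expected_controls, got_controls)):
        # A subclass of the expected control type is acceptable.
        if not issubclass(got, expected):
            lines.append(
                f"term {i}: expected control type {expected.__name__} "
                f"but got control type {got.__name__}"
            )
    return "\n".join(lines)


# Drift term and diffusion term, mirroring the situation in this issue.
message = describe_term_mismatch([float, SpaceTimeLevyArea], [float, BrownianIncrement])
print(message)
```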

ParticularlyPythonicBS commented 3 months ago

That sounds like a great idea; it would make that error more useful even outside the scope of SDE solvers! I look forward to lockwo's input.

lockwo commented 3 months ago

Augmenting the error message is definitely a good idea (related issues: https://github.com/patrick-kidger/diffrax/issues/461, https://github.com/patrick-kidger/diffrax/issues/446). The core problem is that the current message isn't very informative about why the terms are failing. To that end, I think a straightforward augmentation would be to characterize the errors specifically inside the term checker (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_integrate.py#L119) and generate error messages based on that characterization, which would help people narrow down where the error is. Additionally, poorly formed shapes (which prevent even drift.vf from running correctly) are a fairly common cause of this message, especially for scalar/1D systems where there is some squeezing and unsqueezing, so raising specific errors when the eval_shape checks fail could be an option too.
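A hypothetical sketch of the eval_shape idea follows. Each vector field is evaluated abstractly with jax.eval_shape on the initial state, and any failure is re-raised with a message naming the offending term. check_vector_field is an illustrative name, not an existing diffrax function. The failing case reproduces the scalar/1D squeezing problem mentioned above.

```python
import jax
import jax.numpy as jnp


def check_vector_field(name, vf, t0, y0, args):
    """Hypothetical check: abstractly evaluate vf(t0, y0, args) without running it."""
    try:
        return jax.eval_shape(vf, t0, y0, args)
    except Exception as e:
        raise ValueError(
            f"The {name} vector field failed on y0 with shape {jnp.shape(y0)}: {e}"
        ) from e


def duffing_drift(t, state, args):
    x, v = state  # requires state to have shape (2,)
    return jnp.stack([v, -x])


ok = check_vector_field("drift", duffing_drift, 0.0, jnp.array([1.0, 0.0]), None)
print(ok)  # a ShapeDtypeStruct describing the output shape (2,)

try:
    # A scalar state cannot be unpacked into (x, v), so this raises a targeted error.
    check_vector_field("drift", duffing_drift, 0.0, jnp.array(1.0), None)
except ValueError as err:
    print(err)
```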

For the Levy area stuff, these errors are generally caught by the "expected but got" format, but I think it's worth making them extra clear. The default expected/got message would look not too dissimilar from the one above, where you have something like: expected control term diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea] but got diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractBrownianIncrement]. That is pretty clear, but I'd still add a flag or specific text saying "This solver requires a Levy area calculation; you need to add levy_area=diffrax.SpaceTimeLevyArea to your Brownian process", since I think that will be about the second most common error here.
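Here is a hypothetical sketch of that extra hint. It sits on top of the generic expected/got message and appends the levy_area advice only when the solver wanted a Levy-area control but received something that is not one. The stand-in classes are only loosely modeled on diffrax's control types. The function levy_area_hint is an illustrative name, not part of diffrax.

```python
class AbstractBrownianIncrement:  # stand-in, loosely modeled on diffrax's types
    pass


class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement):  # stand-in
    pass


def levy_area_hint(expected, got):
    """Hypothetical: extra advice when a required Levy-area control is missing."""
    if issubclass(expected, AbstractSpaceTimeLevyArea) and not issubclass(got, expected):
        return (
            "This solver requires a Levy area calculation; you need to add "
            "levy_area=diffrax.SpaceTimeLevyArea to your Brownian process."
        )
    return ""  # no extra text for other kinds of mismatch


print(levy_area_hint(AbstractSpaceTimeLevyArea, AbstractBrownianIncrement))
```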

Tangentially, the Levy area docs could probably also be improved. They print a bunch of default attributes that aren't important, and some explanation of what a Levy area is (and what the space-time and space-time-time variants are integrals of) would probably be beneficial.
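As a small illustration of what space-time Levy area is, consider a step [s, t] of length h. Under the usual normalization (assumed here), it is H = (1/h) ∫_s^t (W_{s,u} - ((u - s)/h) W_{s,t}) du, the time-average of the Brownian bridge over the step. It is Gaussian with variance h/12 and independent of the increment W_{s,t}. The Monte Carlo sketch below uses plain NumPy rather than diffrax to check both properties numerically. The grid size and sample count are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
h, n_steps, n_paths = 1.0, 200, 10_000

# Brownian increments on a fine grid over one step [s, s + h] (take s = 0).
dW = rng.normal(0.0, np.sqrt(h / n_steps), size=(n_paths, n_steps))
W = np.cumsum(dW, axis=1)  # W_{s,u} at u = h/n, 2h/n, ..., h
u = np.arange(1, n_steps + 1) * (h / n_steps)
W_h = W[:, -1]  # the increment W_{s,t} over the whole step

bridge = W - (u / h) * W_h[:, None]  # Brownian bridge, zero at both ends of the step
H = bridge.mean(axis=1)  # Riemann sum for (1/h) * integral of the bridge

print(f"Var(H) ~ {H.var():.4f} (theory h/12 = {h / 12:.4f})")
print(f"corr(H, W) ~ {np.corrcoef(H, W_h)[0, 1]:.3f} (theory 0)")
```

Because H carries information about the path inside the step that W_{s,t} alone does not, solvers such as ShARK need the Brownian motion to generate it.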

Happy to take a crack at the above to show what I mean; or if you want to, @ParticularlyPythonicBS, that also works.

ParticularlyPythonicBS commented 2 months ago

@lockwo, you seem to have the expertise with the library to do this much faster than I could, so I would be happy to just follow along. If you are otherwise occupied, though, I am happy to attempt it.

lockwo commented 2 months ago

Here's roughly what I had in mind: https://github.com/patrick-kidger/diffrax/pull/478