SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem

Strange results when using seed kwarg in solve function #532

Closed: stochasticguy closed this issue 10 months ago

stochasticguy commented 1 year ago

Hi everyone,

First of all, I want to congratulate you on this amazing package.

I'm having issues with the solution of a Stochastic Differential Equation (SDE) when passing the seed kwarg to the solve method.

When I don't pass the seed kwarg, the results I get are reasonable. When I do pass it, however, the results are strange (i.e., they are either exactly zero or far from the expected result). I tried several different random number generators, and in every case the results are far from the expected ones.

I know that this problem is stochastic, so I don't expect identical results across different seeds, RNGs, etc. But I'm surprised by how far off they are.

Below there's an example.

using DifferentialEquations, Parameters, RandomNumbers, Statistics

X0 = 90.0
K  = 80.0
r  = 0.0
σ  = 0.2
p  = (m = r, s = σ)

function f(u, p, t)
    @unpack m = p
    m * u
end

function g(u, p, t)
    @unpack s = p
    s * u
end

# Instantiate random number generators
mtw_rng = MersenneTwisters.MT19937()
pcg_rng = PCG.PCGStateOneseq()
pil_rng = Random123.Philox4x()
xor_rng = Xorshifts.Xoroshiro128Plus()

# Create noise processes with the previously created random number generators
W_mtw = WienerProcess(0.0, 1.0, 1.0; rng=mtw_rng)
W_pcg = WienerProcess(0.0, 1.0, 1.0; rng=pcg_rng)
W_pil = WienerProcess(0.0, 1.0, 1.0; rng=pil_rng)
W_xor = WienerProcess(0.0, 1.0, 1.0; rng=xor_rng)

# Create sde problems
sde     = SDEProblem{false}(f, g, X0, (0.0, 1.0), p)
sde_mtw = SDEProblem{false}(f, g, X0, (0.0, 1.0), p; noise=W_mtw)
sde_pcg = SDEProblem{false}(f, g, X0, (0.0, 1.0), p; noise=W_pcg)
sde_pil = SDEProblem{false}(f, g, X0, (0.0, 1.0), p; noise=W_pil)
sde_xor = SDEProblem{false}(f, g, X0, (0.0, 1.0), p; noise=W_xor)

function relu(X, p)
    @unpack K, τ = p
    return max(X(τ) - K, 0.0)
end

# Create Ensemble problems using Relu function as output
ens     = EnsembleProblem(sde, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))
ens_mtw = EnsembleProblem(sde_mtw, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))
ens_pcg = EnsembleProblem(sde_pcg, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))
ens_pil = EnsembleProblem(sde_pil, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))
ens_xor = EnsembleProblem(sde_xor, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))

# Set seed and number of trajectories
seed = 1234
n = 100_000

# When the seed kwarg is passed to solve, the results are way off (the first line is an unseeded baseline)
mean(solve(ens; trajectories=n).u) # The value for this is: 12.96215350985002
mean(solve(ens; trajectories=n, seed=seed).u) # The value for this is: 0.0
mean(solve(ens_mtw; trajectories=n, seed=seed).u) # The value for this is: 0.0
mean(solve(ens_pcg; trajectories=n, seed=seed).u) # The value for this is: 0.0
mean(solve(ens_xor; trajectories=n, seed=seed).u) # The value for this is: 7.115400747775788

# When the seed kwarg is NOT passed to solve, the results are close to the expected value
mean(solve(ens; trajectories=n).u) # The value for this is: 12.861788416616768
mean(solve(ens_mtw; trajectories=n).u) # The value for this is: 12.932072289191881
mean(solve(ens_pcg; trajectories=n).u) # The value for this is: 12.879136140283649
mean(solve(ens_xor; trajectories=n).u) # The value for this is: 12.831574338732501
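For reference, the expected value in this example has a closed form. With m = r = 0 the SDE is geometric Brownian motion, and E[max(X(1) - K, 0)] is the Black-Scholes call price with zero interest rate. The sketch below is an illustration rather than part of the original report: it assumes the Itô interpretation, uses SpecialFunctions.jl for erfc (not in the original imports), and introduces a hypothetical helper named expected_payoff. It gives roughly 12.9, in line with the unseeded runs above.

```julia
using SpecialFunctions  # assumption: provides erfc; not used in the original example

# Standard normal CDF expressed through the complementary error function
Φ(x) = 0.5 * erfc(-x / sqrt(2))

# Hypothetical helper: E[max(X(T) - K, 0)] for the Itô GBM dX = r X dt + σ X dW,
# X(0) = X0. This is the undiscounted payoff; with r = 0 it equals the Black-Scholes call price.
function expected_payoff(X0, K, r, σ, T)
    d1 = (log(X0 / K) + (r + σ^2 / 2) * T) / (σ * sqrt(T))
    d2 = d1 - σ * sqrt(T)
    return X0 * exp(r * T) * Φ(d1) - K * Φ(d2)
end

# Parameters from the report: X0 = 90, K = 80, r = 0, σ = 0.2, τ = 1
println(expected_payoff(90.0, 80.0, 0.0, 0.2, 1.0))  # roughly 12.9
```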

ChrisRackauckas commented 10 months ago

Hello, sorry, this email got buried.

It actually has a rather simple explanation: you're using the same seed for all solves 😅. For an ensemble problem, any keyword argument passed to solve is forwarded to every constituent solve. This is what makes solve(enprob, alg; trajectories = N, abstol = 1e-8) a convenient way to set the tolerance (and other solver options) for all trajectories at once. That behavior is documented at https://docs.sciml.ai/DiffEqDocs/stable/features/ensemble/. So by passing seed as a keyword argument to solve, you set that same seed for every trajectory, and they all use the same random numbers. Every trajectory is then a copy of a single sample path, so the "mean" is just that one path's payoff, which is exactly 0.0 whenever that path ends below K. If you plot the ensemble solution, this is clear.
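A quick way to check this, sketched against the same GBM model as the report: solve a handful of trajectories with a global seed and compare their final values. The choice of SOSRI() and EnsembleSerial() is an assumption (neither appears in the thread), and the check reflects the forwarding behavior described above, so it may differ in versions that handle ensemble seeding differently.

```julia
using DifferentialEquations

# Same GBM model as in the report: drift m * u, diffusion s * u, with m = 0, s = 0.2
f(u, p, t) = p.m * u
g(u, p, t) = p.s * u
prob = SDEProblem(f, g, 90.0, (0.0, 1.0), (m = 0.0, s = 0.2))
ens = EnsembleProblem(prob)

# The seed kwarg is forwarded to every trajectory's solve, so each one draws the same noise
sim = solve(ens, SOSRI(), EnsembleSerial(); trajectories = 5, seed = 1234)
finals = [sol.u[end] for sol in sim.u]

# true when the same seed reaches every solve: all trajectories are the same path
println(allequal(finals))
```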

The correct way to do this, then, is to use prob_func to specify how each trajectory's problem should differ. For example:

function prob_func(prob, i, repeat)
    remake(prob, seed = i)
end

This makes the seed equal to i on the ith trajectory. You may want to vary that, of course, for example by drawing random seeds, but it shows how to force a different seed per trajectory. Then you just do:

ens_pil = EnsembleProblem(sde_pil, prob_func = prob_func, output_func = (sol, i) -> (relu(sol, (K = K, τ = 1.0)), false))

and you see the convergence restored.
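Putting the pieces together, here is a self-contained sketch of the fixed Monte Carlo run. It follows the "taking random seeds" suggestion by drawing one seed per trajectory from a master RNG, which keeps the run reproducible while giving every trajectory its own noise. The master RNG (Xoshiro, Julia 1.7+), the seeds vector, SOSRI(), EnsembleSerial(), and the reduced trajectory count are illustrative assumptions, not part of the thread.

```julia
using DifferentialEquations, Random, Statistics

# Same GBM model as in the report (m = r = 0, s = σ = 0.2), started at X0 = 90
f(u, p, t) = p.m * u
g(u, p, t) = p.s * u
prob = SDEProblem(f, g, 90.0, (0.0, 1.0), (m = 0.0, s = 0.2))

K = 80.0
n = 10_000

# Hypothetical choice: a fixed master seed makes the whole run reproducible,
# while each trajectory still gets its own, distinct seed
master = Xoshiro(1234)
seeds = rand(master, UInt64, n)

# Per-trajectory seed via remake, as in the reply above
prob_func(prob, i, repeat) = remake(prob, seed = seeds[i])
# Payoff max(X(1) - K, 0); false means "do not rerun this trajectory"
output_func(sol, i) = (max(sol.u[end] - K, 0.0), false)

ens = EnsembleProblem(prob; prob_func = prob_func, output_func = output_func)
sim = solve(ens, SOSRI(), EnsembleSerial(); trajectories = n)

est = mean(sim.u)
println(est)  # should be close to the unseeded results in the report (about 12.9)
```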

In summary, StochasticDiffEq.jl lets you set the seed, but if you set the same seed for an entire Monte Carlo simulation, every trajectory draws the same random numbers and you won't get the values you expect. Instead, make sure different trajectories get different random numbers by setting a seed per problem (through prob_func) rather than globally for all problems.

Hopefully that is clear.