patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Example code in the SDE section of Getting started produces an incorrect result #510

Closed. zd9999cs closed this issue 1 month ago.

zd9999cs commented 1 month ago

I am trying to follow the SDE example in the Getting started section of diffrax's documentation, which is:

```python
import jax.random as jr
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

t0, t1 = 1, 3
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = Euler()
saveat = SaveAt(dense=True)

sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=1.0, saveat=saveat)
print(sol.evaluate(1.1))  # DeviceArray(0.89436394)
```

The expected output is 0.89436394, but my machine outputs 0.96505475. I cannot figure out what has gone wrong.
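As a rough sanity check, the value printed by `sol.evaluate(1.1)` is a single random draw, and its distribution can be written down by hand. This assumes the constant step `dt0=0.05` lands on t = 1.1 after two Euler-Maruyama steps from `t0 = 1`, and treats the `VirtualBrownianTree` increments as exact Gaussians (with `tol=1e-3` they are only approximate). Grid times are written $s_n = 1 + nh$ to avoid clashing with the snippet's `t1 = 3`:

$$
y_{n+1} = (1-h)\,y_n + 0.1\,s_n\,\Delta W_n,\qquad \Delta W_n \sim \mathcal{N}(0,h),\qquad h = 0.05
$$

$$
y_2 = (1-h)^2 y_0 + 0.1\,(1-h)\,s_0\,\Delta W_0 + 0.1\,s_1\,\Delta W_1,\qquad y_0 = 1,\ s_0 = 1,\ s_1 = 1.05
$$

$$
\mathbb{E}[y_2] = 0.95^2 = 0.9025,\qquad
\operatorname{Var}[y_2] = h\left[(0.1 \cdot 0.95 \cdot 1)^2 + (0.1 \cdot 1.05)^2\right] = 0.05\,(0.009025 + 0.011025) = 0.0010025
$$

The standard deviation is therefore about 0.0317. The documented 0.89436394 sits about 0.26 standard deviations below the mean and 0.96505475 about 1.98 above it, so both are plausible draws. That is consistent with the two machines having sampled different Brownian paths rather than with a wrong solve.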

I installed diffrax through mamba, and the package list for the environment is as follows:

```
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.33.1               heb4867d_0    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
diffrax                   0.5.1              pyhd8ed1ab_0    conda-forge
equinox                   0.11.5             pyhd8ed1ab_0    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
jax                       0.4.31             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.31          cpu_py312haec0345_2    conda-forge
jaxtyping                 0.2.34             pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_1    conda-forge
libabseil                 20240722.0      cxx17_h5888daf_1    conda-forge
libblas                   3.9.0           24_linux64_openblas    conda-forge
libcblas                  3.9.0           24_linux64_openblas    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.1.0               h77fa898_1    conda-forge
libgcc-ng                 14.1.0               h69a702a_1    conda-forge
libgfortran               14.1.0               h69a702a_1    conda-forge
libgfortran-ng            14.1.0               h69a702a_1    conda-forge
libgfortran5              14.1.0               hc5f4f2c_1    conda-forge
libgomp                   14.1.0               h77fa898_1    conda-forge
libgrpc                   1.65.5               hf5c653b_0    conda-forge
liblapack                 3.9.0           24_linux64_openblas    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.27          pthreads_hac2b453_1    conda-forge
libprotobuf               5.27.5               h5b01275_2    conda-forge
libre2-11                 2023.09.01           hbbce691_3    conda-forge
libsqlite                 3.46.1               hadc24fc_0    conda-forge
libstdcxx                 14.1.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.1.0               h4852527_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
lineax                    0.0.5              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.5.0           py312hf9745cd_0    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
numpy                     2.1.2           py312h58c1407_0    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
optimistix                0.0.7              pyhd8ed1ab_0    conda-forge
pip                       24.2               pyh8b19718_1    conda-forge
python                    3.12.7          hc5c86c4_0_cpython    conda-forge
python_abi                3.12                    5_cp312    conda-forge
re2                       2023.09.01           h77b4e00_3    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.1          py312h7d485d2_0    conda-forge
setuptools                75.1.0             pyhd8ed1ab_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
typeguard                 2.13.3             pyhd8ed1ab_0    conda-forge
typing-extensions         4.12.2               hd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
```

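When comparing outputs across machines, a full environment dump can be trimmed to the packages that most plausibly affect the sampled Brownian path: jax, jaxlib, diffrax and equinox. That choice is an assumption; the issue does not establish which package is responsible. A minimal stdlib-only sketch using `importlib.metadata`, where `package_versions` is a hypothetical helper:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages assumed (for this sketch) to be the most relevant to the result.
PACKAGES = ("jax", "jaxlib", "diffrax", "equinox")


def package_versions(names=PACKAGES):
    """Hypothetical helper: map each package name to its installed version."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Report missing packages instead of raising.
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in package_versions().items():
        print(f"{name:10s} {ver}")
```

JAX also provides `jax.print_environment_info()`, which prints JAX-related version and device information.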
zd9999cs commented 1 month ago

I experimented a bit with the seed passed to jr.PRNGKey(), and the final output varies between roughly 0.8 and 1.0; both my local output and the documented output lie in that range. So the problem seems to be that the random number generator does not behave identically across machines.
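The seed experiment described here can be sketched without diffrax. This is a minimal sketch that assumes only the spread of y(1.1) across random draws matters, not diffrax internals. It swaps diffrax's `VirtualBrownianTree` for Brownian increments drawn directly with `jax.random`, so individual values will not reproduce 0.89436394 or 0.96505475. `euler_maruyama_y_at_1_1` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Same SDE, start point and step size as the documented example:
#   dy = -y dt + 0.1 t dW,  y(1) = 1,  two Euler steps of 0.05 to reach t = 1.1.
# Increments come straight from jax.random (no diffrax, no VirtualBrownianTree).
T0, H, N_STEPS = 1.0, 0.05, 2


def euler_maruyama_y_at_1_1(key):
    """Hypothetical helper: one Euler-Maruyama sample of y(1.1)."""
    dws = jnp.sqrt(H) * jr.normal(key, (N_STEPS,))  # dW_n ~ N(0, H)
    y, t = 1.0, T0
    for n in range(N_STEPS):  # static loop, unrolled at trace time
        y = y - y * H + 0.1 * t * dws[n]
        t = t + H
    return y


# Single draws for a few seeds: each seed gives a different value.
draws = [float(euler_maruyama_y_at_1_1(jr.PRNGKey(seed))) for seed in range(5)]
print(draws)

# Many seeds at once, to see the spread of y(1.1) across random draws.
samples = jax.vmap(euler_maruyama_y_at_1_1)(jr.split(jr.PRNGKey(0), 10_000))

# Closed-form mean and standard deviation of the two-step Euler-Maruyama output.
mean_exact = (1 - H) ** 2
std_exact = (H * ((0.1 * T0 * (1 - H)) ** 2 + (0.1 * (T0 + H)) ** 2)) ** 0.5
print(float(samples.mean()), mean_exact)
print(float(samples.std()), std_exact)
print(float(samples.min()), float(samples.max()))
```

Changing the base key changes every individual draw but leaves the sample mean and spread essentially unchanged, which is the behaviour observed when varying the seed.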