SciML / SciMLExpectations.jl

Fast uncertainty quantification for scientific machine learning (SciML) and differential equations
https://docs.sciml.ai/SciMLExpectations/stable/

Gradient by Zygote of scalar expectation in batch mode #49

Open ArnoStrouwen opened 3 years ago

ArnoStrouwen commented 3 years ago
using ForwardDiff
using Zygote
using OrdinaryDiffEq
using DiffEqUncertainty
using DiffEqSensitivity
using Distributions
using Cubature

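# Lotka-Volterra predator-prey dynamics
function f!(du,u,p,t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2] #prey
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2] #predator
end

tspan = (0.0,10.0)
u0 = [1.0;1.0]
p = [1.5,1.0,3.0,1.0]
# adjoint sensitivity method used when Zygote differentiates through the solve
prob = ODEProblem(f!,u0,tspan,p,sensealg=InterpolatingAdjoint())
# observable: prey population at the final time
g(sol) = sol[1,end]

# differentiate w.r.t. p[1] and p[3]; p[4] and u0[2] are uncertain
p1_3 = [1.5,3.0]
function testf!(p1_3)
    p_dist = [p1_3[1],1.0,p1_3[2],truncated(Normal(1.0,.1),.6, 1.4)]
    u0_dist = [1.0, Uniform(0.8, 1.1)]
    # Koopman expectation with batched cubature (batch mode, the case this issue is about)
    expectation(g, prob, u0_dist, p_dist, Koopman(), Tsit5(), quadalg=CubatureJLp(),batch=32)[1]
end
testf!(p1_3)
ForwardDiff.gradient(testf!,p1_3)
Zygote.gradient(testf!,p1_3)

agerlach commented 3 years ago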

I am able to run this, but the ForwardDiff and Zygote gradients don't match. They do match for batch=0. Is that what you observed? From Slack I got the impression that Zygote was erroring.

ArnoStrouwen commented 3 years ago

ERROR: MethodError: no method matching RecursiveArrayTools.VectorOfArray(::Tuple{Float64})
status `~/Dropbox/julia/small projects/dynamic experimental design tutorial/Project.toml`
  [667455a9] Cubature v1.5.1
  [aae7a2af] DiffEqFlux v1.34.1
  [41bf760c] DiffEqSensitivity v6.43.1
  [ef61062a] DiffEqUncertainty v1.8.0
  [31c24e10] Distributions v0.24.15
  [6a86dc24] FiniteDiff v2.8.0
  [587475ba] Flux v0.11.6
  [f6369f11] ForwardDiff v0.10.17
  [429524aa] Optim v1.3.0
  [1dea7af3] OrdinaryDiffEq v5.52.2
  [91a5bcdd] Plots v1.11.0
  [37e2e3b7] ReverseDiff v1.7.0
  [e88e6eb3] Zygote v0.6.6

ChrisRackauckas commented 3 years ago

Interesting. This doesn't look like the same problem as https://github.com/SciML/Quadrature.jl/issues/49. I wonder if we can isolate it to just Quadrature.jl.

agerlach commented 3 years ago

@ArnoStrouwen Is this in a fresh REPL? Something doesn't seem right, because RecursiveArrayTools doesn't appear in Quadrature or DiffEqUncertainty; I just searched the repos to make sure.

This runs for me with:

  [667455a9] Cubature v1.5.1
  [41bf760c] DiffEqSensitivity v6.43.1
  [ef61062a] DiffEqUncertainty v1.8.0
  [31c24e10] Distributions v0.24.15
  [6a86dc24] FiniteDiff v2.8.0
  [f6369f11] ForwardDiff v0.10.17
  [1dea7af3] OrdinaryDiffEq v5.52.2
  [e88e6eb3] Zygote v0.6.6

But I get

ForwardDiff: 1.8264648587604257, 0.5292345598328614
Zygote: 7.146998823274198, 1.8884006857029283
FiniteDiff: 1.6546890304376016, 3.702681873199691

ArnoStrouwen commented 3 years ago

It might be something that Visual Studio Code loads, since I get the same results as you when I run from the terminal.