SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Possible Upstream Tracker AD Issue #607

Open CharlesRSmith44 opened 4 years ago

CharlesRSmith44 commented 4 years ago

Hello,

I'm trying to solve forward-backward stochastic differential equations (FBSDEs). The problem is:

$$-dY_t = (A_t Y_t + B_t Z_t + C_t)\,dt - Z_t\,dW_t, \qquad Y_T \sim N(m, s), \qquad T = 1$$

I consider the case where A_t = 1, B_t = 1, and C_t = 1.

The analytic solution is a path Z_t together with an initial value Y_0 that satisfies the terminal condition Y_T ~ N(m, s). For A_t = B_t = C_t = 1, both Z_t and Y_0 are constant, and the solution is given explicitly in the code.
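For reference, the constant solution used in the code can be checked by hand. Assuming $Z_t \equiv Z$ is constant and $A_t = B_t = C_t = 1$, the equation reads $dY_t = -(Y_t + Z + 1)\,dt + Z\,dW_t$; multiplying by the integrating factor $e^{t}$ and applying the Itô isometry for the variance gives:

```latex
\begin{aligned}
d\left(e^{t} Y_t\right) &= e^{t}\left(dY_t + Y_t\,dt\right) = -e^{t}(Z + 1)\,dt + e^{t} Z\,dW_t \\
e^{T} Y_T &= Y_0 - (Z + 1)\left(e^{T} - 1\right) + Z \int_0^T e^{t}\,dW_t \\
\mathbb{E}[Y_T] &= e^{-T}\left(Y_0 - (Z + 1)\left(e^{T} - 1\right)\right) = m
  \;\Longrightarrow\; Y_0 = e^{T} m + \left(e^{T} - 1\right)(1 + Z) \\
\operatorname{Var}[Y_T] &= Z^2 e^{-2T} \int_0^T e^{2t}\,dt = \frac{Z^2\left(1 - e^{-2T}\right)}{2} = s
  \;\Longrightarrow\; Z = \sqrt{\frac{2s}{1 - e^{-2T}}}
\end{aligned}
```

Taking the positive root for Z, these match the `Y0` and `Z0` expressions in the code, with s read as the target variance.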

I'm trying to use DiffEqFlux to find this solution. There are two sets of parameters: the initial positions of Y and Z, and the parameters of the neural net that determines the path of Z_t given Z_{t-1}, Y_{t-1}, and t. I generate many paths of Z_t and Y_t using SDEProblem and concrete_solve, compute the mean and variance of Y_T, and compare them to the specified mean and variance; that gives the loss. I then use sciml_train to compute gradients and update the parameters.
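The pipeline described above can be sketched without DiffEqFlux for the no-NN case. This is a minimal sketch under stated assumptions: Z is held constant, Euler-Maruyama replaces SOSRI, and `simulate_YT` and `moment_loss` are hypothetical helper names, not library API. It evaluates the moment-matching loss at the analytic (Y_0, Z_0), where a sample-based loss is small but not exactly zero because of Monte Carlo noise and time-discretization error.

```julia
using Random, Statistics

# Hypothetical helper (not DiffEqFlux API): Euler-Maruyama simulation of
# dY = -(Y + Z0 + 1) dt + Z0 dW with Z held constant (no-NN case, A = B = C = 1).
function simulate_YT(Y0, Z0; T = 1.0, nsteps = 100, npaths = 10_000,
                     rng = Random.default_rng())
    dt = T / nsteps
    Y = fill(float(Y0), npaths)              # one entry per sample path
    for _ in 1:nsteps
        dW = sqrt(dt) .* randn(rng, npaths)  # independent Brownian increments
        Y .+= (.-Y .- Z0 .- 1) .* dt .+ Z0 .* dW
    end
    return Y                                 # samples of Y_T
end

# Moment-matching loss: squared errors in the sample mean and variance of Y_T.
moment_loss(samples, m, s) = (mean(samples) - m)^2 + (var(samples) - s)^2

# Target moments and the analytic constant solution from the issue.
m, s, T = 3.0, 2.0, 1.0
Z0 = sqrt(2s / (1 - exp(-2T)))
Y0 = exp(T) * m + (exp(T) - 1) * (1 + Z0)

samples = simulate_YT(Y0, Z0; T = T, rng = MersenneTwister(1))
println(moment_loss(samples, m, s))  # small, but not exactly zero
```

Fixing the random seed makes the loss deterministic, which is convenient when comparing candidate (Y_0, Z_0) values by hand.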

I attached two Julia files where I try to do just that.

Normal_SDEs_no_NN.txt

Normal_SDEs_NN.txt

The first file has no neural net, only the initial positions; the second includes the neural net. However, I've noticed two problems with the code:

  1. When I run without the neural net and only optimize the initial positions Y_0 and Z_0, everything runs without errors, but the result differs from the analytic solution and consistently produces a higher loss.
  2. When I run with the neural net, I can optimize either the initial positions or the NN parameters, but not both.

If anyone has ideas to explain or solve either problem, I would greatly appreciate it.

ChrisRackauckas commented 4 years ago

I see:

function sde_data_gen(θ)
    prob = SDEProblem(trueSDEfunc, trueNOISEfunc, θ[end-1:end], tspan, θ)
    # if θ is included in concrete_solve, it updates the NN parameters,
    #sol = Array(concrete_solve(prob, SOSRI(), θ, force_dtmin = true, saveat = trange))[1,end]
    # if θ is not included, it updates the initial positions.
    sol = Array(concrete_solve(prob, SOSRI(), force_dtmin = true, saveat = trange))[1,end]
    sol
end

You might want to take a look at https://docs.juliadiffeq.org/latest/analysis/sensitivity/#High-Level-Interface:-concrete_solve-1

concrete_solve(prob,alg,u0=prob.u0,p=prob.p,args...;
               sensealg=InterpolatingAdjoint(),
               checkpoints=sol.t,kwargs...)

The third argument is an override for u0 and the fourth is an override for p: this interface is there to make the AD overloads for du(t)/du0 and du(t)/dp easily computable from adjoint methods. Given that, I assume you wanted:

sol = Array(concrete_solve(prob, SOSRI(), θ[end-1:end], θ[1:end-2], force_dtmin = true, saveat = trange))[1,end]
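With the layout θ = [p1; u0] used in the issue, the two overrides are complementary slices of θ, passed u0 first and p second. A small sketch with made-up numbers standing in for `initial_params(model)` (hypothetical values, not from the issue):

```julia
# Hypothetical stand-ins: plain vectors in place of initial_params(model) and u0.
p1 = Float32[0.1, -0.2, 0.3, 0.05]   # pretend NN parameters
u0 = Float32[10.0, 1.0]              # initial Y and Z
θ  = [p1; u0]                        # same layout as in the issue: NN params, then u0

# concrete_solve(prob, alg, u0, p; ...) takes the u0 override third and p fourth:
u0_override = θ[end-1:end]           # last two entries -> initial Y and Z
p_override  = θ[1:end-2]             # everything else -> NN parameters

# Equivalent split that does not hard-code the number of states:
np = length(p1)
@assert θ[np+1:end] == u0_override && θ[1:np] == p_override
```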

Let me know if you need anything else. Cheers.

CharlesRSmith44 commented 4 years ago

Thanks for the quick response. That is exactly what I want. However, when I make the change you suggest, sciml_train only updates the u0 parameters, not the p parameters, and the solution doesn't appear to be good (at least compared to the analytic solution). Normal_SDEs_NN_v2.txt

ChrisRackauckas commented 4 years ago

This setup is a bit confusing. I think the improper use of collect is dropping gradients. Anyway, I'd stick to the out-of-place (oop) form for reverse mode and do something like:

function trueSDEfunc(u,p,t)
    true_A = Float32(-1.0)
    true_B = Float32(-1.0)
    true_C = Float32(-1.0)
    in_val = [u[2]; t]
    [u[1].*true_A .+ u[2].* true_B .+ true_C,first(model(in_val,p[1:length(p1)]))]
end

which should work.
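The difference between the in-place and out-of-place (oop) function forms can be shown without any SciML packages. In this sketch `nn_standin` is a hypothetical linear replacement for `model(in_val, p)`; only the two signatures matter. Reverse-mode AD with Zygote, for example, does not support mutating arrays, which is one reason the out-of-place form is preferred there.

```julia
# Hypothetical linear stand-in for the neural net call model(in_val, p).
nn_standin(x, p) = p[1] * x[1] + p[2] * x[2]

# In-place form: writes into the preallocated du and returns nothing.
function drift!(du, u, p, t)
    du[1] = -u[1] - u[2] - 1             # same as true_A = true_B = true_C = -1
    du[2] = nn_standin([u[2], t], p)     # Z's drift from (Z, t)
    return nothing
end

# Out-of-place form: allocates and returns a new vector, with no mutation.
drift(u, p, t) = [-u[1] - u[2] - 1, nn_standin([u[2], t], p)]

u, p, t = [10.0, 1.0], [0.5, -0.2], 0.3
du = similar(u)
drift!(du, u, p, t)
println(du == drift(u, p, t))   # prints true: both forms compute the same drift
```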

CharlesRSmith44 commented 4 years ago

I made your adjustment and now sciml_train updates neither the parameters of the NN nor the initial positions. Any ideas? Normal_SDEs_NN_v3.txt

Thank you for all the help so far.

ChrisRackauckas commented 4 years ago

I was suggesting:

function trueSDEfunc(u,p,t)
    true_A = Float32(-1.0)
    true_B = Float32(-1.0)
    true_C = Float32(-1.0)
    in_val = [u[2]; t]
    [u[1].*true_A .+ u[2].* true_B .+ true_C,first(model(in_val,p[1:length(p1)]))]
end

function trueNOISEfunc(u,p,t)
    [u[2],zero(eltype(u))]
end

(Please paste code using ```julia instead of using text files)
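The noise function returns one entry per state component, which corresponds to diagonal noise: each component is scaled elementwise by its own Brownian increment, so Y is driven by Z dW while Z gets no noise. A minimal Euler-Maruyama sketch, assuming diagonal noise and, for illustration only, a zero drift for Z in place of the neural net; `f`, `g`, `em_step`, and `simulate` are hypothetical names.

```julia
using Random

# Hypothetical drift: the linear part for Y and, for illustration only,
# zero drift for Z (in the issue, Z's drift comes from the neural net).
f(u, p, t) = [-u[1] - u[2] - 1, zero(eltype(u))]

# Diagonal noise, as in trueNOISEfunc: Y is driven by Z dW, Z has no noise.
g(u, p, t) = [u[2], zero(eltype(u))]

# One Euler-Maruyama step with diagonal noise: every component gets its own
# Brownian increment ΔW_i ~ N(0, dt), scaled elementwise by g.
function em_step(u, p, t, dt, rng)
    ΔW = sqrt(dt) .* randn(rng, length(u))
    return u .+ f(u, p, t) .* dt .+ g(u, p, t) .* ΔW
end

function simulate(u0, nsteps, dt, rng)
    u = copy(u0)
    for n in 0:nsteps-1
        u = em_step(u, nothing, n * dt, dt, rng)
    end
    return u
end

uT = simulate([10.0, 1.0], 25, 0.04, MersenneTwister(42))
println(uT)  # uT[2] stays 1.0: Z has zero drift here and zero noise
```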

CharlesRSmith44 commented 4 years ago

Okay. Thank you for all the help. I made both of those adjustments, and still sciml_train updates neither the parameters of the NN nor the initial positions. Any ideas?


using Flux, DiffEqFlux, StochasticDiffEq, Plots, DiffEqNoiseProcess, Zygote
using Statistics, DifferentialEquations, Random, DiffEqBase, Distributions

### Initial Inputs
means = Float32(3.0) #mean of distribution of YT
vars = Float32(2.0) #variance of distribution of YT
datasize = 26 #number of datapoints
tspan = (0.0f0, 1.0f0)  #datarange
trange = range(tspan[1], tspan[2],length=datasize)
dt = (tspan[2] - tspan[1])/(datasize-1)
u0 = [10.0,1.0] #initial positions of y and z

### Solution
Y0 = exp(tspan[2])*means + (exp(tspan[2])-1) * (1+sqrt(2*vars/(1-exp(-2*tspan[2]))))
Z0 = sqrt(2*vars/(1-exp(-2*tspan[2])))

## Neural Net Model
model = FastChain((x,p)->x,
                FastDense(2,5,relu),
                FastDense(5,1))
ps_m = Flux.params(model)
p1 = initial_params(model)
θ = [p1; u0] #set of parameters that are optimized over: parameters of NN and initial positions

function trueSDEfunc(u,p,t)
    true_A = Float32(-1.0)
    true_B = Float32(-1.0)
    true_C = Float32(-1.0)
    in_val = [u[2]; t]
    [u[1].*true_A .+ u[2].* true_B .+ true_C,first(model(in_val,p[1:length(p1)]))]
end

function trueNOISEfunc(u,p,t)
    [u[2],zero(eltype(u))]
end

function sde_data_gen(p)
    prob = SDEProblem(trueSDEfunc, trueNOISEfunc, p[end-1:end], tspan, p)
    sol = Array(concrete_solve(prob, SOSRI(), p[end-1:end], p[1:end-2], force_dtmin = true, saveat = trange))[1,end]
    sol
end

function loss(p)
    samples = [sde_data_gen(p) for i in 1:256]
    mean_val = mean(samples)
    variance_val = var(samples)
    totalloss = (mean_val - means)^2 + (variance_val .- vars)^2
    totalloss
end

@show loss(θ)

### Training NN

opt1 = ADAM(0.1)
res1 = DiffEqFlux.sciml_train(loss, θ, opt1, maxiters = 15)
ChrisRackauckas commented 4 years ago

Interesting... @dhairyagandhi96 can you help figure out what's going on with Tracker here?

using Flux, DiffEqFlux, StochasticDiffEq, Plots, DiffEqNoiseProcess, Zygote
using Statistics, Random, DiffEqSensitivity, DiffEqBase, Distributions

### Initial Inputs
means = Float32(3.0) #mean of distribution of YT
vars = Float32(2.0) #variance of distribution of YT
datasize = 26 #number of datapoints
tspan = (0.0f0, 1.0f0)  #datarange
trange = range(tspan[1], tspan[2],length=datasize)
dt = (tspan[2] - tspan[1])/(datasize-1)
u0 = [10.0,1.0] #initial positions of y and z

### Solution
Y0 = exp(tspan[2])*means + (exp(tspan[2])-1) * (1+sqrt(2*vars/(1-exp(-2*tspan[2]))))
Z0 = sqrt(2*vars/(1-exp(-2*tspan[2])))

## Neural Net Model
model = FastChain((x,p)->x,
                FastDense(2,5,relu),
                FastDense(5,1))
ps_m = Flux.params(model)
p1 = initial_params(model)
θ = [p1; u0] #set of parameters that are optimized over: parameters of NN and initial positions

function trueSDEfunc(u,p,t)
    true_A = Float32(-1.0)
    true_B = Float32(-1.0)
    true_C = Float32(-1.0)
    in_val = [u[2]; t]
    [u[1]*true_A + u[2]* true_B + true_C,first(model(in_val,p[1:length(p1)]))]
end

function trueNOISEfunc(u,p,t)
    [u[2],zero(eltype(u))]
end

prob = SDEProblem(trueSDEfunc, trueNOISEfunc, nothing, tspan, nothing)
function sde_data_gen(p)
    Array(concrete_solve(prob, SOSRI(), p[end-1:end], p[1:end-2], saveat = trange, sensealg=TrackerAdjoint()))[1,end]
end

function loss(p)
    samples = [sde_data_gen(p) for i in 1:256]
    mean_val = mean(samples)
    variance_val = var(samples)
    totalloss = (mean_val - means)^2 + (variance_val .- vars)^2
    totalloss
end

@show loss(θ)

### Training NN

opt1 = ADAM(0.1)
res1 = DiffEqFlux.sciml_train(loss, θ, opt1, maxiters = 15)
CharlesRSmith44 commented 4 years ago

Hello, I'm just checking in on this to see if you have thought about it and if you have ideas to solve this issue.

Thank you!

ChrisRackauckas commented 4 years ago

We're directly implementing an adjoint for stochastic differential equations and that'll both be faster and fix this issue.

CharlesRSmith44 commented 4 years ago

Awesome, thank you very much! Do you have an estimate as to when that will be ready for use?

Also, on a slightly different note, do you have any examples or packages that solve infinite-horizon FBSDEs, or the HJB equations associated with infinite-horizon control problems (where the terminal condition is that the solution to the HJB PDE or the FBSDE does not explode "too quickly" as time goes to infinity, rather than a more standard boundary condition at a terminal time T)? Thank you.

ChrisRackauckas commented 4 years ago

> Awesome, thank you very much! Do you have an estimate as to when that will be ready for use?

It's one of our summer projects:

https://summerofcode.withgoogle.com/organizations/6363760870031360/?sp-page=2#5505348691034112

Probably earlier in the summer rather than later.

> Also, on a slightly different note, do you have any examples or packages that solve infinite-horizon FBSDEs, or the HJB equations associated with infinite-horizon control problems (where the terminal condition is that the solution to the HJB PDE or the FBSDE does not explode "too quickly" as time goes to infinity, rather than a more standard boundary condition at a terminal time T)? Thank you.

Does this cover what you're looking for? https://arxiv.org/pdf/1908.01602.pdf This is another one of our summer projects.

CharlesRSmith44 commented 4 years ago

> It's one of our summer projects:
>
> https://summerofcode.withgoogle.com/organizations/6363760870031360/?sp-page=2#5505348691034112
>
> Probably earlier in the summer rather than later.

Awesome, that timeline is fantastic for me. Thank you so much.

> Does this cover what you're looking for? https://arxiv.org/pdf/1908.01602.pdf This is another one of our summer projects.

Hm, this is interesting and useful, but not exactly what I am looking for. I'm more interested in the type of problem outlined in this paper: https://www.sciencedirect.com/science/article/pii/S0377042713002549. The main difference is the time horizon (finite vs. infinite), which affects the cost function in equation (6) of the paper.

ChrisRackauckas commented 4 years ago

I wonder if that aligns with @jlperla

ChrisRackauckas commented 4 years ago

Work on SDE adjoints is being tracked here: https://github.com/SciML/DiffEqSensitivity.jl/pull/242

CharlesRSmith44 commented 4 years ago

Has progress been made on this issue? I'm resuming my work on this topic and am just checking in. Thank you!

ChrisRackauckas commented 4 years ago

The SDE adjoints exist. I think after https://github.com/SciML/StochasticDiffEq.jl/pull/332 @frankschae will start to benchmark and optimize the SDE adjoints.

CharlesRSmith44 commented 4 years ago

Thank you. This looks very promising. I'm still a little unsure how to translate this into the problem I'm looking at.

I want to write a program where the SDE solver is used to compute the distribution of Y_T for a given Y_0 and Z_0. Then we can compute the error as (mean(Y_T) - mu)^2 + (var(Y_T) - sigma)^2. Then a neural net that takes mu and sigma as inputs is used to update the initial positions Y_0 and Z_0.
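When Z is constant and A = B = C = 1, the moments of Y_T are known in closed form, so the error described above can be minimized without an SDE solver. That gives a noise-free target to compare a sample-based optimizer against. A minimal sketch under those assumptions, for one fixed (mu, sigma) with sigma read as the target variance; it uses plain gradient descent with hand-derived gradients instead of sciml_train, does not include the neural net mapping (mu, sigma) to (Y_0, Z_0), and all function names are hypothetical.

```julia
# Closed-form moments of Y_T for constant Z and A = B = C = 1
# (from e^T Y_T = Y0 - (Z0 + 1)(e^T - 1) + Z0 ∫ e^t dW_t).
mean_YT(Y0, Z0, T) = exp(-T) * (Y0 - (Z0 + 1) * (exp(T) - 1))
var_YT(Z0, T) = Z0^2 * (1 - exp(-2T)) / 2

# The error described above, with sigma read as the target variance s.
closed_form_loss(Y0, Z0, m, s, T) = (mean_YT(Y0, Z0, T) - m)^2 + (var_YT(Z0, T) - s)^2

# Hand-derived gradient of closed_form_loss with respect to (Y0, Z0).
function closed_form_grad(Y0, Z0, m, s, T)
    em = mean_YT(Y0, Z0, T) - m        # error in the mean
    ev = var_YT(Z0, T) - s             # error in the variance
    dY0 = 2 * em * exp(-T)
    dZ0 = -2 * em * exp(-T) * (exp(T) - 1) + 2 * ev * Z0 * (1 - exp(-2T))
    return dY0, dZ0
end

# Plain gradient descent on (Y0, Z0) for a fixed target (m, s).
function fit(m, s; T = 1.0, Y0 = 10.0, Z0 = 1.0, lr = 0.1, iters = 5_000)
    for _ in 1:iters
        gY, gZ = closed_form_grad(Y0, Z0, m, s, T)
        Y0 -= lr * gY
        Z0 -= lr * gZ
    end
    return Y0, Z0
end

Y0_fit, Z0_fit = fit(3.0, 2.0)
println((Y0_fit, Z0_fit))   # should approach the analytic Y0 and Z0
```

Since the variance depends only on Z0^2, both signs of Z0 minimize the loss; starting from a positive Z0 keeps the iterates on the positive root used in the issue's code.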

If there are examples that do something similar to this that I can look at, that would be greatly appreciated. Thank you!

ChrisRackauckas commented 4 years ago

Something like the Deep BSDE methods? We have an implementation in https://github.com/SciML/NeuralNetDiffEq.jl, with the translation described in https://arxiv.org/abs/2001.04385, and it already works with the old TrackerAdjoint. You can look at the code here: https://github.com/SciML/NeuralNetDiffEq.jl/blob/master/src/pde_solve_ns.jl. This is one of the examples we'll be testing with the new adjoints, though.

ChrisRackauckas commented 4 years ago

Let's keep this thread specifically to the AD issue here, and follow usage updates in https://github.com/SciML/DiffEqFlux.jl/issues/312