SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
865 stars 153 forks source link

Correct way to index into parameter vector? #281

Closed metanoid closed 4 years ago

metanoid commented 4 years ago

I'm trying to get sciml_train to build an estimator for u0 and p, given a known functional form of the ODE.

Inside the `predict` function, I split my parameter vector θ into two parts by indexing (the `p_u0_m = θ[...]` and `p_p_m = θ[...]` lines in the example below).

Oddly enough, the first of these indexing operations works but the second does not: when training with `sciml_train` I get `DimensionMismatch("array could not be broadcast to match destination")`.

Is there a better way to structure the code so that indexing into the parameter vector will not fail?

Reproducible example:

using DifferentialEquations
using Flux, DiffEqFlux
using DiffEqSensitivity
using Plots
using Dates
using Statistics
using Optim

"""
    ModelJ!(du, u, p, t)

In-place right-hand side of the three-state model used in this example.

With `u = (A, B, C)` and a six-entry coefficient vector `p`, write into `du`:

    du[1] = -p[1] * (A - p[2] * B * C) * 0.01
    du[2] = -p[3] * (B - p[4] * C * A) * 0.01
    du[3] = -p[5] * (C - p[6] * A * B) * 0.01

`t` is not used (the right-hand side has no explicit time dependence).
Returns `nothing`: the result is communicated only through `du`.
"""
function ModelJ!(du, u, p, t)
    A, B, C = u
    du[1] = -p[1] * (A - p[2] * B * C) * 0.01
    du[2] = -p[3] * (B - p[4] * C * A) * 0.01
    du[3] = -p[5] * (C - p[6] * A * B) * 0.01
    # In-place functions conventionally return nothing instead of leaking the
    # last assigned value.
    return nothing
end

# Reference problem: random initial state and parameters, solved once and plotted.
# u0, p and tspan are globals reused further down: length(u0) and length(p) size
# the estimator networks, and tspan is the integration interval used in `predict`.
u0 = rand(3)
p = rand(6)
tspan = (0.0, 100.0)

# Deterministic solve, saved at unit time steps, for a quick visual check.
prob = ODEProblem(ModelJ!, u0, tspan, p)
sol = solve(prob, saveat = 1.0)
plot(sol, legend = :outertopright)

"""
    noise!(du, u, p, t)

In-place diffusion term for the SDE version of `ModelJ!`: noise proportional to
the state, `du = 0.01 .* u`. `p` and `t` are unused.

The diffusion must be written into `du`. The previous one-liner
`noise!(du,u,p,t) = 0.01 .* u` allocated a new array and left `du` unchanged, so
the intended noise was never applied. Returns `nothing`.
"""
function noise!(du, u, p, t)
    du .= 0.01 .* u
    return nothing
end

"""
    generate_fake_data()

Simulate one noisy trajectory of the model, with `ModelJ!` as drift and `noise!`
as diffusion, on `t ∈ [0, 100]`. It starts from a random 3-element state and uses
a random 6-element parameter vector. The trajectory is saved every 1.0 time units
and returned as `Array(solution)`.
"""
function generate_fake_data()
    # Draw the random initial state first, then the random parameters.
    initial_state = rand(3)
    true_params = rand(6)
    sde = SDEProblem(ModelJ!, noise!, initial_state, (0.0, 100.0), true_params)
    trajectory = solve(sde, saveat = 1.0)
    return Array(trajectory)
end
# Smoke test: one simulated stream.
test = generate_fake_data()

"""
    generate_fake_dataset(n = 5)

Return a vector of `n` independently simulated streams, each produced by one call
to `generate_fake_data()`. Returns an empty vector when `n == 0`.

The result is built with a comprehension, so its element type comes from the
generated streams instead of being fixed to `Any`. The previous
`streams = []; push!(...)` version returned a `Vector{Any}`. That `Any`
propagated into `batch` and into the data handed to `sciml_train`.
"""
function generate_fake_dataset(n = 5)
    return [generate_fake_data() for _ in 1:n]
end

fake_set = generate_fake_dataset(3)

# Now build a model that estimates the initial state u0 and the parameters p that generated each data stream, given that we know the functional form of the ODE

# we need 3 pieces:
# estimate the starting u0
# estimate the fixed params p
# predict the extrapolation given these

# Two RNN estimators. Each is flattened into a parameter vector plus a rebuild
# function (`re_*`) that turns a parameter vector back into a model:
#   u0Net takes one observation (length(u0) values) and outputs length(u0) values;
#   pNet  takes one observation (length(u0) values) and outputs length(p) values.
u0Net = RNN(length(u0), length(u0), tanh)
p_u0, re_u0 = Flux.destructure(u0Net)
pNet = RNN(length(u0), length(p), tanh)
p_p, re_p = Flux.destructure(pNet)

# Offsets of each network's parameters inside the concatenated vector θ:
# the first length(p_u0) entries belong to u0Net, the following length(p_p) to pNet.
p_u0_start = 1
p_p_start = length(p_u0) + 1

# Full trainable parameter vector, and one example stream for smoke-testing.
θ = [p_u0; p_p]
d = fake_set[1]

"""
    predict(θ, d)

Estimate an initial state and a parameter vector from the observed stream `d`
(its columns are indexed as successive observations) using the two RNNs encoded
in `θ`. Then solve `ModelJ!` over the global `tspan` with those estimates, saving
every 1.0 time units.

Returns `(pred, full_sol)`: `Array(full_sol)` and the solution object itself.

NOTE(review): depends on the globals `p_u0`, `p_p`, `p_u0_start`, `p_p_start`,
`re_u0`, `re_p` and `tspan`. The estimates come from RNNs built with `tanh`, so
they are presumably confined to (-1, 1). Confirm this range suits the model.
"""
function predict(θ, d)
    #step 1 - estimate u0 from d
    # extract the u0-estimating parameters from θ
    p_u0_m = θ[range(p_u0_start, length = length(p_u0))] # no error reported for this line
    # build the u0 estimator
    u0_estimator = re_u0(p_u0_m)
    # reverse d: columns are fed last-to-first, so the final input the RNN sees
    # is the first column of d
    d_rev = [d[:,j] for j in size(d,2):-1:1]
    # now estimate u0: run the RNN over the reversed sequence and keep its last output
    # NOTE(review): this assumes the broadcast evaluates the stateful RNN in order,
    # carrying its hidden state between calls. The thread attributes the gradient
    # failure to upstream RNN issues (FluxML/Flux.jl#1209); it was reported fixed
    # on the Flux stateful-map branch.
    u0_seq = u0_estimator.(d_rev)
    u0 = u0_seq[end]

    #step 2 - estimate p from d
    # extract the p-estimating parameters from θ
    p_p_m = θ[range(p_p_start, length = length(p_p))] # reported DimensionMismatch is raised in the pullback of this line during sciml_train
    # build the p estimator
    p_estimator = re_p(p_p_m)
    # now estimate p (same reversed-sequence scheme as for u0)
    p_seq = p_estimator.(d_rev)
    p = p_seq[end]

    # step 3 - solve the ode with this u0 and this p to get the prediction
    prob = ODEProblem(ModelJ!, u0, tspan, p)
    full_sol = solve(prob, saveat = 1.0)
    pred = Array(full_sol)
    return pred, full_sol
end

# Forward-pass smoke test (no gradients involved).
pred, sol = predict(θ, d)

# Number of leading observations of each stream that the estimators get to see.
T = 50

"""
    loss_batched(θ, batch)

Sum, over every stream `d` in `batch`, of the squared prediction error. For each
stream the error is summed over states and averaged over predicted time points.

The estimators only see the first `T` columns of each stream (`d[:, 1:T]`). The
prediction, however, covers the whole integration interval, so the error is
measured over every predicted column. Both the fit on the first `T` points and
the extrapolation beyond them are scored. If the solve returns fewer columns,
only those columns are compared.

NOTE(review): reads the non-`const` global `T`. The comparison assumes `pred`
has no more columns than `d`. An empty `batch` gives `0.0`.
"""
function loss_batched(θ, batch)
    loss = 0.0
    for d in batch
        pred, _ = predict(θ, d[:, 1:T])
        npred = size(pred, 2)
        actual = d[:, 1:npred]
        # note - this will need to be rewritten to account for missing data
        loss += sum(abs2, pred .- actual) / npred
    end
    return loss
end

# Train on the three simulated streams together; first evaluate the initial loss.
batch = fake_set[1:3]
l = loss_batched(θ, batch)

# now to try to train the model to provide good guesses for u0 and p
# `[[batch]]` is a one-element data collection whose single entry is the argument
# list `[batch]`. Presumably sciml_train passes it after θ, i.e. it calls
# `loss_batched(θ, batch)`. The stack trace shows loss_batched being called from
# sciml_train. This is the call whose gradient raised DimensionMismatch in the
# reported run.
test_ode = DiffEqFlux.sciml_train(loss_batched, θ, ADAM(0.01), [[batch]], progress = false)

When running the final `sciml_train` line, the error is reported at line 84 of my script (the `p_p_m = θ[...]` line in `predict`) but not at line 73 (the `p_u0_m = θ[...]` line).

Error message in full:

ERROR: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
 [1] check_broadcast_shape at .\broadcast.jl:509 [inlined]
 [2] check_broadcast_axes at C:\Users\username\.julia\packages\DiffEqBase\3jJQt\src\diffeqfastbc.jl:26 [inlined]
 [3] check_broadcast_axes at .\broadcast.jl:516 [inlined]
 [4] instantiate at .\broadcast.jl:259 [inlined]
 [5] materialize! at .\broadcast.jl:823 [inlined]
 [6] (::Zygote.var"#1040#1042"{Array{Float32,1},Tuple{UnitRange{Int64}}})(::Array{Any,1}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\array.jl:45
 [7] (::Zygote.var"#2739#back#1036"{Zygote.var"#1040#1042"{Array{Float32,1},Tuple{UnitRange{Int64}}}})(::Array{Any,1}) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [8] predict at .\untitled-fc4b4f9e8f268e3062377bf4e032f7a1:84 [inlined]
 [9] (::typeof(∂(predict)))(::Tuple{Array{Float64,2},Nothing}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [10] loss_batched at .\untitled-fc4b4f9e8f268e3062377bf4e032f7a1:110 [inlined]
 [11] (::typeof(∂(loss_batched)))(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [12] (::Zygote.var"#174#175"{typeof(∂(loss_batched)),Tuple{Tuple{Nothing},Int64}})(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:182
 [13] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(loss_batched)),Tuple{Tuple{Nothing},Int64}}})(::Float64) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [14] #25 at C:\Users\username\.julia\packages\DiffEqFlux\GpzmI\src\train.jl:99 [inlined]
 [15] (::typeof(∂(λ)))(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [16] (::Zygote.var"#49#50"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:179
 [17] gradient(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:55
 [18] sciml_train(::Function, ::Array{Float32,1}, ::ADAM, ::Array{Array{Array{Any,1},1},1}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\username\.julia\packages\DiffEqFlux\GpzmI\src\train.jl:98
 [19] top-level scope at none:0

Versions: Julia 1.4.2

  [aae7a2af] DiffEqFlux v1.12.0 #master (https://github.com/SciML/DiffEqFlux.jl.git)
  [41bf760c] DiffEqSensitivity v6.19.1 #master (https://github.com/SciML/DiffEqSensitivity.jl.git)
  [0c46a032] DifferentialEquations v6.14.0
  [587475ba] Flux v0.10.5 #master (https://github.com/FluxML/Flux.jl.git)
  [429524aa] Optim v0.20.6
  [91a5bcdd] Plots v1.3.7
metanoid commented 4 years ago

I tried to put together a smaller MWE (adapted from the tutorials), but it does not reproduce the error, and I don't understand why.

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, Plots

# Initial state of the two-state system, integration interval and save times.
u0 = [0.1f0, 1.1f0]
tspan = (0.0f0, 25.0f0)
# NOTE(review): the step `1.0` is a Float64 literal, so this range is promoted to
# Float64 despite the Float32 endpoints. Use `1.0f0` if Float32 was intended.
tsteps = 0.0f0:1.0:25.0f0

# Network that drives the first state (2 inputs -> 1 output).
model_univ = FastChain(FastDense(2, 16, tanh),
                       FastDense(16, 16, tanh),
                       FastDense(16, 1))

# The model weights are destructured into a vector of parameters
p_model = initial_params(model_univ)
n_weights = length(p_model)

# Parameters of the second equation (linear dynamics)
p_system = Float32[0.5, -0.5]

# Full ODE parameter vector: network weights followed by α and β.
p_all = [p_model; p_system]
# NOTE(review): this θ is never used. It is reassigned (to the estimator weights)
# before any use further down.
θ = Float32[u0; p_all]

"""
    dudt_univ!(du, u, p, t)

In-place right-hand side of the two-state system.

`p` is laid out as `[network weights (n_weights entries); α; β]`:

    du[1] = model_univ(u, p[1:n_weights])[1]     # control, produced by the network
    du[2] = α * u[2] + β * u[1]                   # system output, linear dynamics

`t` is unused. Returns `nothing`.

NOTE(review): reads the globals `n_weights` and `model_univ`. It assumes
`length(p) == n_weights + 2`. With a shorter `p`, `α`/`β` silently alias other
entries instead of raising an error.
"""
function dudt_univ!(du, u, p, t)
    # Destructure the parameters
    model_weights = p[1:n_weights]
    α = p[end - 1]
    β = p[end]

    # The neural network outputs a control taken by the system;
    # the system then produces an output.
    model_control, system_output = u

    # Dynamics of the control and system (all inputs were read above).
    du[1] = model_univ(u, model_weights)[1]
    du[2] = α * system_output + β * model_control
    # In-place functions conventionally return nothing.
    return nothing
end

# Reference problem with the initial parameters, and one accurate solve of it.
prob_univ = ODEProblem(dudt_univ!, u0, tspan, p_all)
sol_univ = solve(prob_univ, Tsit5(), abstol = 1e-8, reltol = 1e-6)

# let's say we want to estimate u0 and p from the fixed data of "[1]"
# Each estimator maps the constant 1-element input to a full u0 / p vector.
u0_est = FastChain(FastDense(1, length(u0)))
# The estimated `p` replaces `p_all` in `prob_univ`, so it must have the same
# length: n_weights network weights followed by α and β. With only `n_weights`
# outputs, `dudt_univ!` would read α and β from the last two network weights.
p_est = FastChain(FastDense(1, length(p_all)))
a_params = initial_params(u0_est)
b_params = initial_params(p_est)

# Trainable vector: u0-estimator weights followed by p-estimator weights.
θ = [a_params; b_params]

"""
    predict_univ(θ)

Split `θ` into the `u0_est` weights (the first `length(a_params)` entries) and
the `p_est` weights (the rest). Evaluate each network on the constant input
`[1.0f0]` and solve `prob_univ` with the resulting `u0` and `p`, saving at
`tsteps`. Returns the solution as an array.
"""
function predict_univ(θ)
    n_a = length(a_params)
    u0_weights = θ[1:n_a]
    p_weights = θ[n_a + 1:end]
    input = [1.0f0]
    solution = solve(prob_univ, Tsit5();
                     u0 = u0_est(input, u0_weights),
                     p = p_est(input, p_weights),
                     saveat = tsteps)
    return Array(solution)
end

# Loss: squared distance of the second state (the system output) from 1, summed
# over every saved time point.
function loss_univ(θ)
    system_output = predict_univ(θ)[2, :]
    return sum(abs2, system_output .- 1)
end
l = loss_univ(θ)

# Global state shared with the training callback.
list_plots = []
iter = 0
# Training callback: prints the current loss, plots the current prediction and
# keeps every plot in `list_plots`. The list is reset on the first call
# (iter == 0). It always returns `false`, presumably telling sciml_train to keep
# optimizing.
callback = function (θ, l)
    global list_plots, iter

    iter == 0 && (list_plots = [])
    iter += 1

    println(l)

    fig = plot(predict_univ(θ)', ylim = (0, 6))
    push!(list_plots, fig)
    display(fig)
    return false
end

# Fit θ (the u0_est and p_est weights) with BFGS, reporting progress via `callback`.
result_univ = DiffEqFlux.sciml_train(loss_univ, θ,
                                     BFGS(initial_stepnorm = 0.01),
                                     cb = callback)
ChrisRackauckas commented 4 years ago

It could just be upstream RNN issues like https://github.com/FluxML/Flux.jl/issues/1209

metanoid commented 4 years ago

Confirmed: it works correctly on the Flux `stateful-map` branch. Thanks!