SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

StackOverflowError on a simple example #149

Closed: loiseaujc closed this issue 4 years ago

loiseaujc commented 4 years ago

Hi guys,

I have successfully run the Lotka-Volterra example. I am now looking at a classical optimal-control benchmark, the inverted pendulum on a cart, where the controller is a fairly simple neural net (just to give it a try). Unfortunately, whenever I try to train the model, I get a StackOverflowError, and I am not sure how to fix it.
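For reference, these are the equations of motion that controlled_pendulum! encodes, transcribed from the code: the state is u = (x, θ, ẋ, θ̇), and K(u; W) is the scalar output of the neural-net controller with weights W. The code evaluates the right-hand side and then applies M \ (a linear solve, not an explicit inverse) to obtain du/dt.

$$
\begin{pmatrix}
1 & 0 & 0 & 0\\
0 & 1 & 0 & 0\\
0 & 0 & m_1 + m_2 & -m_2 l \cos\theta\\
0 & 0 & -\cos\theta & l
\end{pmatrix}
\frac{d}{dt}
\begin{pmatrix} x\\ \theta\\ \dot x\\ \dot\theta \end{pmatrix}
=
\begin{pmatrix}
\dot x\\
\dot\theta\\
-m_2 l \dot\theta^2 \sin\theta - k_1 \dot x + K(u; W)\\
g \sin\theta - k_2 \dot\theta
\end{pmatrix}
$$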

Below is my code.

using DifferentialEquations, DiffEqFlux
using Optim
using LinearAlgebra

# --> Parameters of the cart.
m1 = 0.9
k1 = 0.1

# --> Parameters of the pendulum.
m2 = 0.1
k2 = 0.1
l = 10
g = 9.818

# --> Pack parameters.
p = [m1, m2, k1, k2, l, g]

# --> Define controller.
K = FastChain(
    FastDense(4, 4, tanh),
    FastDense(4, 1, identity)
)

W = initial_params(K)

# --> Dynamics of the cart-pole.
function controlled_pendulum!(du, u, p, t)

    # --> Unpack variables.
    x, θ, dx, dθ = u

    # --> Unpack physical parameters.
    m₁, m₂, k₁, k₂, l, g = p[1:6]

    # --> Unpack neural net parameters.
    W = p[7:end]

    # --> Mass matrix.
    M = [
        1 0 0 0;
        0 1 0 0;
        0 0 m₁ + m₂ -m₂*l*cos(θ);
        0 0 -cos(θ) l]

    # --> Right-hand side of the equations of motion.
    du[1] = dx
    du[2] = dθ
    du[3] = -m₂*l*dθ^2*sin(θ) - k₁*dx + K([x, θ, dx, dθ], W)[1]
    du[4] = g*sin(θ) - k₂*dθ

    # --> Invert mass matrix.
    du = M \ du

    return du
end

# --> Setup the ODE problem.
u₀ = Float32[0.0, 0.1, 0.0, 0.0]
params = Float32[p ; W]

tspan = (0.0f0, 10.0f0)
prob = ODEProblem(controlled_pendulum!, u₀, tspan, params)

# --> Predict function.
function predict(θ)
   Array(concrete_solve(prob, Tsit5(), u₀, [p; θ], saveat=0.1))
end

# --> Loss and callback functions.
function loss(θ)
    return norm(predict(θ))^2
end

l = loss(W)

cb = function(θ, l)
    println(l)
    return false
end

# --> Train the controller.
res = DiffEqFlux.sciml_train(loss, W, BFGS(), cb=cb)
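The loss above penalizes the whole trajectory: Array(...) turns the solution into a 4 × N matrix U(W) with one column per saveat time, and norm applied to a matrix (from LinearAlgebra) is the Frobenius norm, i.e. all entries treated as one vector. Written out:

$$
\mathcal{L}(W) = \bigl\lVert U(W) \bigr\rVert_F^2 = \sum_{k} \bigl\lVert u(t_k; W) \bigr\rVert_2^2, \qquad t_k \in \{0, 0.1, \dots, 10\}.
$$

Minimizing it asks the controller to bring the cart and pendulum to rest at x = θ = 0.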

Whenever I run this code, here is the error I get:

StackOverflowError:

Stacktrace:
 [1] promote_rule(::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}, ::Type{Tracker.TrackedReal{Float64}}) at /home/loiseau/.julia/packages/Tracker/cpxco/src/lib/real.jl:61
 [2] promote_type(::Type{Tracker.TrackedReal{Float64}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at ./promotion.jl:223
 [3] promote_rule(::Type{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at /home/loiseau/.julia/packages/Tracker/cpxco/src/lib/real.jl:61
 [4] promote_type at ./promotion.jl:223 [inlined]
 [5] promote_result(::Type, ::Type, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at ./promotion.jl:237
 [6] promote_type(::Type{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}, ::Type{Tracker.TrackedReal{Float64}}) at ./promotion.jl:223
 ... (the last 6 lines are repeated 15998 more times)
 [95995] promote_rule(::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}, ::Type{Tracker.TrackedReal{Float64}}) at /home/loiseau/.julia/packages/Tracker/cpxco/src/lib/real.jl:61
 [95996] promote_type(::Type{Tracker.TrackedReal{Float64}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at ./promotion.jl:223
 [95997] promote_rule(::Type{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at /home/loiseau/.julia/packages/Tracker/cpxco/src/lib/real.jl:61
 [95998] promote_type at ./promotion.jl:223 [inlined]
 [95999] promote_result(::Type, ::Type, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}, ::Type{Tracker.TrackedReal{Tracker.TrackedReal{Tracker.TrackedReal{Float64}}}}) at ./promotion.jl:237
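The trace is a cycle: promote_rule and promote_type keep calling each other on differently nested TrackedReal types without ever reaching a base case, until the stack runs out. A toy analogue in plain Julia reproduces the same shape. The Wrapped type and its rule below are hypothetical and are not Tracker's actual code; the rule is deliberately broken so that, instead of returning a type, it defers back to promote_type with the arguments swapped.

```julia
# Toy type standing in for a tracked number (hypothetical; not Tracker's code).
struct Wrapped <: Real
    x::Float64
end

# Deliberately broken rule: instead of returning a type, it asks promote_type
# again. promote_type(Float64, Wrapped) evaluates promote_rule in both argument
# orders, which calls this method again, so the recursion never terminates.
Base.promote_rule(::Type{Wrapped}, ::Type{Float64}) = promote_type(Float64, Wrapped)

# StackOverflowError is catchable, so we can confirm the kind of failure.
overflowed = try
    promote_type(Wrapped, Float64)
    false
catch err
    err isa StackOverflowError
end

println(overflowed)  # true
```

A correct rule returns a type directly (for example Wrapped), which is what lets promote_type terminate.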

I am pretty new to Julia, so any help is welcome :)

++ JC

ChrisRackauckas commented 4 years ago

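You hit an AD bug. I think it is easier to use the out-of-place (u,p,t) style instead of the mutating (du,u,p,t) one (you had that syntax wrong anyway: du = M \ du rebinds the local variable instead of writing into the du array the solver passed in). With that, here's a workaround: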

using DifferentialEquations, DiffEqFlux
using Optim
using LinearAlgebra

# --> Parameters of the cart.
m1 = 0.9
k1 = 0.1

# --> Parameters of the pendulum.
m2 = 0.1
k2 = 0.1
l = 10
g = 9.818

# --> Pack parameters.
p = [m1, m2, k1, k2, l, g]

# --> Define controller.
K = FastChain(
    FastDense(4, 4, tanh),
    FastDense(4, 1)
)

W = initial_params(K)

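# --> Dynamics of the cart-pole (out-of-place: returns du instead of mutating it).
function controlled_pendulum!(u, p, t)

    # --> Unpack variables.
    x, θ, dx, dθ = u

    # --> Unpack physical parameters.
    m₁, m₂, k₁, k₂, l, g = p[1:6]

    # --> Unpack neural net parameters.
    W = p[7:end]

    # --> Right-hand side of the equations of motion.
    du = [dx, dθ, -m₂*l*dθ^2*sin(θ) - k₁*dx + K(u, W)[1], g*sin(θ) - k₂*dθ]

    # --> Mass matrix, built with reshape instead of a matrix literal to work
    #     around the AD bug. reshape fills column by column, so the entries are
    #     listed one column at a time to reproduce the row-wise M above.
    M = reshape([1, 0, 0, 0,
                 0, 1, 0, 0,
                 0, 0, m₁ + m₂, -cos(θ),
                 0, 0, -m₂*l*cos(θ), l], 4, 4)

    # --> Solve M * du/dt = du (a linear solve, no explicit inverse).
    return M \ du
end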

# --> Setup the ODE problem.
u₀ = Float32[0.0, 0.1, 0.0, 0.0]
params = Float32[p ; W]

tspan = (0.0f0, 10.0f0)
prob = ODEProblem(controlled_pendulum!, u₀, tspan, params)

# --> Predict function.
function predict(θ)
   Array(concrete_solve(prob, Tsit5(), u₀, [p; θ], saveat=0.1))
end

# --> Loss and callback functions.
function loss(θ)
    return norm(predict(θ))^2
end

l = loss(W)

cb = function(θ, l)
    println(l)
    return false
end

# --> Train the controller.
res = DiffEqFlux.sciml_train(loss, W, BFGS(), cb=cb)
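DifferentialEquations.jl accepts the right-hand side either in mutating form f(du, u, p, t), which must write its result into du, or in out-of-place form f(u, p, t), which returns a new vector. The minimal sketch below shows why du = M \ du fails in the mutating form: rebinding du only changes which array the local name refers to, whereas du .= ... writes into the caller's array. The function names and the 2×2 system are hypothetical stand-ins for the pendulum.

```julia
using LinearAlgebra

# Mutating style, as in the original post: the last line rebinds the local
# name `du` to a brand-new array, so the caller's buffer keeps the raw f.
function rhs_rebind!(du, M, f)
    du .= f
    du = M \ du
    return nothing
end

# Mutating style done right: `.=` writes the solution into the caller's buffer.
function rhs_mutate!(du, M, f)
    du .= M \ f
    return nothing
end

# Out-of-place style, as in the workaround: just return the new vector.
rhs_outofplace(M, f) = M \ f

M = [2.0 0.0; 0.0 4.0]   # stands in for the mass matrix
f = [2.0, 8.0]           # stands in for the right-hand side

buf1 = zeros(2)
rhs_rebind!(buf1, M, f)
println(buf1)                  # [2.0, 8.0]  (M \ f was computed, then discarded)

buf2 = zeros(2)
rhs_mutate!(buf2, M, f)
println(buf2)                  # [1.0, 2.0]

println(rhs_outofplace(M, f))  # [1.0, 2.0]
```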

Note that your model is unstable, so you will need to experiment with it (for example, by changing the initial condition of the controlled variable). The unusual mass-matrix construction, building M with reshape from a flat vector instead of a matrix literal, is the workaround for the bug; I'll get it fixed upstream.

ChrisRackauckas commented 4 years ago

This is the Zygote issue that the reshape works around: https://github.com/FluxML/Zygote.jl/issues/513
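One detail to keep in mind with the reshape trick: Julia arrays are column-major, so reshape fills its result column by column, and the flat vector must list the entries of M one column at a time. Listing them row by row, the way one reads a matrix literal, produces the transpose. A quick check in plain Julia follows, with hypothetical parameter values chosen so that m₂·l ≠ 1. With the thread's values, m₂·l = 0.1 × 10 = 1, so the two off-diagonal entries coincide and the ordering happens not to matter.

```julia
# Hypothetical values; m₂ * l = 0.5 here so that M is not symmetric.
m₁, m₂, l, θ = 0.9, 0.1, 5.0, 0.1

# Matrix literal written row by row, as in the original post.
M_literal = [1 0 0 0;
             0 1 0 0;
             0 0 (m₁ + m₂) (-m₂ * l * cos(θ));
             0 0 (-cos(θ)) l]

# reshape fills column by column, so list the entries one column at a time.
M_colwise = reshape([1, 0, 0, 0,                        # column 1
                     0, 1, 0, 0,                        # column 2
                     0, 0, m₁ + m₂, -cos(θ),            # column 3
                     0, 0, -m₂ * l * cos(θ), l], 4, 4)  # column 4

# Listing the entries row by row instead yields the transpose.
M_rowwise = reshape([1, 0, 0, 0,
                     0, 1, 0, 0,
                     0, 0, m₁ + m₂, -m₂ * l * cos(θ),
                     0, 0, -cos(θ), l], 4, 4)

println(M_colwise == M_literal)               # true
println(M_rowwise == permutedims(M_literal))  # true
println(M_rowwise == M_literal)               # false
```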

ChrisRackauckas commented 4 years ago

Fixed upstream.