SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Issue with batch training of NeuralODE with signals as inputs #225

Closed junsebas97 closed 4 years ago

junsebas97 commented 4 years ago

I’m trying to use the DiffEqFlux package to define a machine learning model for an ODE. The model I need is a Neural ODE whose inputs are the initial condition and a continuous time signal.


Currently, I have been able to train the NN on just one example with:

dydt  = FastChain(FastDense(3,  10), FastDense(10,  2))
dydtₚ = initial_params(dydt)

function Node(du, u, p, t)

    f_t    = p[1](t)
    params = p[2]

    du[1] = dydt([u; f_t], params)[1]
    du[2] = dydt([u; f_t], params)[2]
end

function prediction(fₜ, θ)
    prob =  ODEProblem(Node, u0, tspan, [fₜ, θ])
    concrete_solve(prob, Tsit5(), saveat = Δt)
end

loss_fn(θ) = sum(abs2, solution - prediction(f_t, θ))
res = DiffEqFlux.sciml_train(loss_fn, dydtₚ, LBFGS())

That works perfectly, but it uses just one signal for training. When I try to train on a batch of functions (using several signals) with:

function loss_batch(θ)

    n_exam   = length(signals)
    loss_val = 0

    for i in 1:n_exam
        guess     = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end

    loss_val
end

res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())

I get this message:

ERROR: LoadError: MethodError: no method matching AbstractFloat(::var"#19#29")
Closest candidates are:
  AbstractFloat(::Bool) at float.jl:258 
  AbstractFloat(::Int8) at float.jl:259 
  AbstractFloat(::Int16) at float.jl:260

I think the bug is caused by the loop, but I'm not sure.

ChrisRackauckas commented 4 years ago

The right thing to use is probably a closure. Here's an example of that:

using DiffEqFlux, OrdinaryDiffEq, Flux
dydt  = FastChain(FastDense(3,  10), FastDense(10,  2))
dydtₚ = initial_params(dydt)

function Node(u, p, t, f_t)
    dydt([u; f_t(t)], p)
end

function prediction(fₜ, θ)
    prob =  ODEProblem((u,p,t)->Node(u,p,t,fₜ), u0, tspan, θ)
    concrete_solve(prob, Tsit5(), saveat = Δt)
end
solution = ones(2,11)

f_t = (t)->t;
u0 = ones(2)
Δt = 0.1
tspan = (0.0,1.0)
loss(dydtₚ) = sum(abs2, solution - prediction(f_t, dydtₚ))
res = DiffEqFlux.sciml_train(loss, dydtₚ, LBFGS())

solution = [ones(2,11) for i in 1:10]
signals = [(t)->-i*t for i in 1:10]
function loss_batch(θ)

    n_exam   = length(signals)
    loss_val = 0

    for i in 1:n_exam
        guess     = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end

    loss_val
end

res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())

Let me know if you need anything else. Cheers.
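
For readers who want to see the closure pattern without the DiffEqFlux dependencies, here is a minimal sketch in plain Julia. The names `rhs` and `make_rhs` and the linear map standing in for the network are hypothetical, not part of the library. The point is that the signal is captured by an anonymous function, so the parameter argument stays purely numeric. The last lines show that packing a function and numbers into one array yields an untyped container, which is a plausible (unconfirmed) reason the original `[fₜ, θ]` parameter array failed.

```julia
# Stand-in for the neural network: a linear map applied to [u; f(t)].
# `W` plays the role of the trainable parameters θ (numbers only).
rhs(u, W, t, f) = W * [u; f(t)]

# Enclose the signal so the solver-facing function has the usual (u, p, t) form.
make_rhs(f) = (u, p, t) -> rhs(u, p, t, f)

W = [1.0 0.0  2.0;
     0.0 1.0 -1.0]
u = [1.0, 1.0]

signals  = [t -> -i * t for i in 1:3]      # one signal per training example
rhs_list = [make_rhs(f) for f in signals]  # one closure per signal

# Each closure evaluates its own signal at the requested time.
du = [g(u, W, 0.5) for g in rhs_list]
println(du)

# By contrast, mixing a function and numeric data in one array
# produces a container with no concrete numeric element type.
mixed = [signals[1], W]
println(eltype(mixed))
```

With this layout, the batch loss only has to loop over the closures (or over the signals, building the closure inside `prediction` as in the example above), while the optimizer sees a plain numeric parameter vector.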

junsebas97 commented 4 years ago

Hi, I tried to implement the closure in my code and also ran your example, but I always get this error:

LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:253

Actually, with the closure, training fails even with just one signal.

ChrisRackauckas commented 4 years ago

Are you saying that the example code didn't work for you? Or are you talking about some other code?

junsebas97 commented 4 years ago

Both. I tried to implement the closure in other code and it didn't work, then I copied the example you gave me verbatim and it didn't work either.

I always get the same message:

LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:25

ChrisRackauckas commented 4 years ago

Can you show ]st and ]st -m? I was testing that from Julia v1.4.1

junsebas97 commented 4 years ago

I use JuliaPro 1.4.0-1

]st

Status `C:\Users\asus\.juliapro\JuliaPro_v1.4.0-1\environments\v1.4\Project.toml`
  [c52e3926] Atom v0.12.10 ⚲
  [aae7a2af] DiffEqFlux v1.8.1
  [41bf760c] DiffEqSensitivity v6.10.2
  [0c46a032] DifferentialEquations v6.12.0
  [587475ba] Flux v0.10.3
  [7073ff75] IJulia v1.20.0
  [e5e0dc1b] Juno v0.8.1 ⚲
  [429524aa] Optim v0.20.6
  [1dea7af3] OrdinaryDiffEq v5.32.1
  [4722fa14] PkgAuthentication v0.1.2
  [91a5bcdd] Plots v0.29.9
  [d330b81b] PyPlot v2.8.2
  [e88e6eb3] Zygote v0.4.8

]st -m

Status `C:\Users\asus\.juliapro\JuliaPro_v1.4.0-1\environments\v1.4\Manifest.toml`
  [621f4979] AbstractFFTs v0.5.0
  [1520ce14] AbstractTrees v0.3.2
  [79e6a3ab] Adapt v1.0.1
  [ec485272] ArnoldiMethod v0.0.4
  [7d9fca2a] Arpack v0.4.0
  [68821587] Arpack_jll v3.5.0+3
  [4fba245c] ArrayInterface v2.6.2
  [4c555306] ArrayLayouts v0.2.1
  [bf4720bc] AssetRegistry v0.1.0
  [c52e3926] Atom v0.12.10 ⚲
  [aae01518] BandedMatrices v0.15.1
  [b99e7846] BinaryProvider v0.5.91
  [a134a8b2] BlackBoxOptim v0.5.0
  [764a87c0] BoundaryValueDiffEq v2.3.0
  [6e34b625] Bzip2_jll v1.0.6+2
  [fa961155] CEnum v0.2.0
  [a9c8d775] CPUTime v1.0.0
  [00ebfdb7] CSTParser v2.2.0
  [3895d2a7] CUDAapi v3.1.0
  [c5f51814] CUDAdrv v6.0.0
  [be33ccc6] CUDAnative v2.10.2
  [a603d957] CanonicalTraits v0.2.1
  [7057c7e9] Cassette v0.3.1
  [d360d2e6] ChainRulesCore v0.7.1
  [53a63b46] CodeTools v0.7.0
  [da1fd8a2] CodeTracking v0.5.8
  [944b1d66] CodecZlib v0.6.0
  [3da002f7] ColorTypes v0.9.1
  [5ae59095] Colors v0.11.2
  [bbf7d656] CommonSubexpressions v0.2.0
  [34da2185] Compat v3.6.0
  [e66e0078] CompilerSupportLibraries_jll v0.3.3+0
  [8f4d0f93] Conda v1.4.1
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [d38c429a] Contour v0.5.2
  [3a865a2d] CuArrays v1.7.3
  [9a962f9c] DataAPI v1.1.0
  [864edb3b] DataStructures v0.17.10
  [bcd4f6db] DelayDiffEq v5.23.0
  [2b5f629d] DiffEqBase v6.25.1
  [459566f4] DiffEqCallbacks v2.12.1
  [01453d9d] DiffEqDiffTools v1.7.0
  [5a0ffddc] DiffEqFinancial v2.2.1
  [aae7a2af] DiffEqFlux v1.8.1
  [c894b116] DiffEqJump v6.5.0
  [77a26b50] DiffEqNoiseProcess v3.9.0
  [055956cb] DiffEqPhysics v3.5.0
  [41bf760c] DiffEqSensitivity v6.10.2
  [163ba53b] DiffResults v1.0.2
  [b552c78f] DiffRules v1.0.1
  [0c46a032] DifferentialEquations v6.12.0
  [c619ae07] DimensionalPlotRecipes v1.1.0
  [b4f34e82] Distances v0.8.2
  [31c24e10] Distributions v0.23.2
  [33d173f1] DocSeeker v0.4.1
  [ffbed154] DocStringExtensions v0.8.1
  [d4d017d3] ExponentialUtilities v1.6.0
  [c87230d0] FFMPEG v0.3.0
  [b22a6f82] FFMPEG_jll v4.1.0+2
  [7a1cc6ca] FFTW v1.2.0
  [f5851436] FFTW_jll v3.3.9+5
  [5789e2e9] FileIO v1.2.2
  [1a297f60] FillArrays v0.8.6
  [6a86dc24] FiniteDiff v2.3.0
  [53c48c17] FixedPointNumbers v0.7.1
  [08572546] FlameGraphs v0.2.0
  [587475ba] Flux v0.10.3
  [59287772] Formatting v0.4.1
  [f6369f11] ForwardDiff v0.10.10
  [d7e528f0] FreeType2_jll v2.10.1+2
  [559328eb] FriBidi_jll v1.0.5+2
  [069b7b12] FunctionWrappers v1.1.1
  [de31a74c] FunctionalCollections v0.5.0
  [0c68f7d7] GPUArrays v2.0.1
  [28b8d3ca] GR v0.46.90
  [6b9d7cbe] GeneralizedGenerated v0.2.2
  [01680d73] GenericSVD v0.3.0
  [4d00f742] GeometryTypes v0.8.2
  [cd3eb016] HTTP v0.8.12
  [9fb69e20] Hiccup v0.2.2
  [7073ff75] IJulia v1.20.0
  [7869d1d1] IRTools v0.3.1
  [9b13fd28] IndirectArrays v0.5.1
  [d25df0c9] Inflate v0.1.2
  [83e8ac13] IniFile v0.5.0
  [1d5cc7b8] IntelOpenMP_jll v2018.0.3+0
  [42fd0dbc] IterativeSolvers v0.8.3
  [82899510] IteratorInterfaceExtensions v1.0.0
  [682c06a0] JSON v0.21.0
  [98e50ef6] JuliaFormatter v0.3.6
  [aa1ae85d] JuliaInterpreter v0.7.12
  [b14d175d] JuliaVariables v0.2.0
  [e5e0dc1b] Juno v0.8.1 ⚲
  [c1c5ebd0] LAME_jll v3.100.0+0
  [929cbde3] LLVM v1.3.4
  [7c4cb9fa] LNR v0.2.1
  [b964fa9f] LaTeXStrings v1.1.0
  [23fbe1c1] Latexify v0.13.0
  [a5e1c1ea] LatinHypercubeSampling v1.2.0
  [50d2b5c4] Lazy v0.15.0
  [1d6d02ad] LeftChildRightSiblingTrees v0.1.2
  [dd192d2f] LibVPX_jll v1.8.1+1
  [093fc24a] LightGraphs v1.3.1
  [d3d80556] LineSearches v7.0.1
  [e6f89c97] LoggingExtras v0.4.0
  [856f044c] MKL_jll v2019.0.117+2
  [d8e11817] MLStyle v0.3.1
  [1914dd2f] MacroTools v0.5.4
  [739be429] MbedTLS v1.0.0
  [c8ffd9c3] MbedTLS_jll v2.16.0+1
  [442fdcdd] Measures v0.3.1
  [e89f7d12] Media v0.5.0
  [e1d29d7a] Missings v0.4.3
  [961ee093] ModelingToolkit v1.4.2
  [46d2c3a1] MuladdMacro v0.2.2
  [f9640e96] MultiScaleArrays v1.6.0
  [d41bc354] NLSolversBase v7.6.1
  [2774e3e8] NLsolve v4.3.0
  [872c559c] NNlib v0.6.6
  [77ba4419] NaNMath v0.3.3
  [71a1bf82] NameResolution v0.1.3
  [510215fc] Observables v0.3.1
  [e7412a2a] Ogg_jll v1.3.3+0
  [4536629a] OpenBLAS_jll v0.3.9+0
  [458c3c95] OpenSSL_jll v1.1.1+2
  [efe28fd5] OpenSpecFun_jll v0.5.3+3
  [429524aa] Optim v0.20.6
  [91d4177d] Opus_jll v1.3.1+0
  [bac558e1] OrderedCollections v1.1.0
  [1dea7af3] OrdinaryDiffEq v5.32.1
  [90014a1f] PDMats v0.9.12
  [65888b18] ParameterizedFunctions v5.0.3
  [d96e819e] Parameters v0.12.0
  [69de0a69] Parsers v0.3.12
  [fa939f87] Pidfile v1.1.0
  [4722fa14] PkgAuthentication v0.1.2
  [ccf2f8ad] PlotThemes v1.0.2
  [995b91a9] PlotUtils v0.6.5
  [91a5bcdd] Plots v0.29.9
  [e409e4f3] PoissonRandom v0.4.0
  [85a6dd25] PositiveFactorizations v0.2.3
  [8162dcfd] PrettyPrint v0.1.0
  [33c8b6b6] ProgressLogging v0.1.2
  [92933f4c] ProgressMeter v1.2.0
  [438e738f] PyCall v1.91.4
  [d330b81b] PyPlot v2.8.2
  [1fd47b50] QuadGK v2.3.1
  [8a4e6c94] QuasiMonteCarlo v0.1.2
  [e6cf234a] RandomNumbers v1.4.0
  [3cdcf5f2] RecipesBase v0.8.0
  [731186ca] RecursiveArrayTools v2.1.2
  [f2c3362d] RecursiveFactorization v0.1.0
  [189a3867] Reexport v0.2.0
  [ae029012] Requires v1.0.1
  [ae5879a3] ResettableStacks v1.0.0
  [37e2e3b7] ReverseDiff v1.1.0
  [79098fc4] Rmath v0.6.1
  [f50d1b31] Rmath_jll v0.2.2+0
  [f2b01f46] Roots v1.0.1
  [992d4aef] Showoff v0.3.1
  [699a6c99] SimpleTraits v0.9.1
  [ed01d8cd] Sobol v1.3.0
  [b85f4697] SoftGlobalScope v1.0.10
  [a2af1166] SortingAlgorithms v0.3.1
  [47a9eef4] SparseDiffTools v1.4.0
  [d4ead438] SpatialIndexing v0.1.2
  [276daf66] SpecialFunctions v0.10.0
  [90137ffa] StaticArrays v0.12.1
  [2913bbd2] StatsBase v0.33.0
  [4c63d2b9] StatsFuns v0.9.4
  [9672c7b4] SteadyStateDiffEq v1.5.0
  [789caeaf] StochasticDiffEq v6.19.1
  [88034a9c] StringDistances v0.6.3
  [c3572dad] Sundials v3.9.0
  [3783bdb8] TableTraits v1.0.0
  [5d786b92] TerminalLoggers v0.1.1
  [a759f4b9] TimerOutputs v0.5.3
  [0796e94c] Tokenize v0.5.7
  [37b6cedf] Traceur v0.3.0
  [9f7883ad] Tracker v0.2.6
  [3bb67fe8] TranscodingStreams v0.9.5
  [a2a6695c] TreeViews v0.3.0
  [30578b45] URIParser v0.4.0
  [3a884ed6] UnPack v0.1.0
  [81def892] VersionParsing v1.2.0
  [19fa3120] VertexSafeGraphs v0.1.1
  [0f1e0344] WebIO v0.8.92
  [104b5d7c] WebSockets v1.5.2
  [cc8bc4a8] Widgets v0.6.2
  [c2297ded] ZMQ v1.2.0
  [8f1865be] ZeroMQ_jll v4.3.2+1
  [a5390f91] ZipFile v0.9.1
  [83775a58] Zlib_jll v1.2.11+9
  [e88e6eb3] Zygote v0.4.8
  [700de1a5] ZygoteRules v0.2.0
  [0ac62f75] libass_jll v0.14.0+0
  [f638f0a6] libfdk_aac_jll v0.1.6+1
  [f27f6e37] libvorbis_jll v1.3.6+2
  [1270edf5] x264_jll v2019.5.25+1
  [dfaa095f] x265_jll v3.0.0+0
  [2a0f44e3] Base64 
  [ade2ca70] Dates
  [8bb1440f] DelimitedFiles
  [8ba89e20] Distributed 
  [7b1f6079] FileWatching 
  [b77e0a4c] InteractiveUtils
  [76f85450] LibGit2 
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging 
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [44cfe95a] Pkg 
  [de0858da] Printf 
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA
  [9e88b42a] Serialization 
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets 
  [2f01184e] SparseArrays
  [10745b16] Statistics
  [4607b0f0] SuiteSparse
  [8dfed614] Test
  [cf7118a7] UUIDs 
  [4ec0a83e] Unicode

ChrisRackauckas commented 4 years ago

Your Zygote version is far behind. Try updating it and see if that fixes the error.
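
A minimal sketch of the update using the standard Pkg API, assuming the default environment is active. Note that packages marked ⚲ in the status output are pinned and will not move, and that compatibility bounds of other packages can hold Zygote back, in which case updating the whole environment with `Pkg.update()` may be needed.

```julia
using Pkg

Pkg.update("Zygote")   # same as `]up Zygote` in the Pkg REPL
Pkg.status()           # same as `]st`; check the Zygote version listed
```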

junsebas97 commented 4 years ago

Yeah, it works with the current version of Zygote. Thanks!!

A final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?

ChrisRackauckas commented 4 years ago

> A final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?

Yes, since you're enclosing the function: it is called as f_t(t) each time the solver evaluates the derivative.
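
To make this concrete, here is a small sketch in plain Julia that counts how often an enclosed signal is evaluated. It uses a hand-written fixed-step forward Euler loop as a stand-in for `Tsit5`, so `euler`, `signal`, and `calls` are illustrative names only. An adaptive Runge-Kutta method would call the right-hand side several times per step, at its internal stage times, but the mechanism is the same: every derivative evaluation calls the captured signal with the current `t`.

```julia
# Fixed-step forward Euler integrator (a simple stand-in for a real solver).
function euler(f, u0, tspan, dt)
    t0, t1 = tspan
    n = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    for k in 0:n-1
        t = t0 + k * dt
        u = u + dt * f(u, nothing, t)   # one derivative evaluation per step
    end
    return u
end

calls = Ref(0)                            # evaluation counter
signal(t) = (calls[] += 1; sin(t))        # time-varying input signal

# Closure: du/dt = -u + signal(t), with `signal` captured from the enclosing scope.
rhs = (u, p, t) -> -u .+ signal(t)

u_end = euler(rhs, [1.0, 1.0], (0.0, 1.0), 0.1)
println(calls[])   # number of times the signal was evaluated
```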

junsebas97 commented 4 years ago

Ok thanks!