junsebas97 closed this issue 4 years ago.
The right thing to use is probably a closure. Here's an example of that:
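using DiffEqFlux, OrdinaryDiffEq, Flux

# Neural network: the input is the 2-dimensional state u plus the 1-dimensional signal f(t)
dydt = FastChain(FastDense(3, 10), FastDense(10, 2))
dydtₚ = initial_params(dydt)

# Right-hand side: append the external signal, evaluated at time t, to the state
function Node(u, p, t, f_t)
    dydt([u; f_t(t)], p)
end

# The anonymous function closes over fₜ, so the ODE only sees (u, p, t)
function prediction(fₜ, θ)
    prob = ODEProblem((u,p,t)->Node(u,p,t,fₜ), u0, tspan, θ)
    concrete_solve(prob, Tsit5(), saveat = Δt)
end

# Training on a single signal
solution = ones(2,11)
f_t = (t)->t;
u0 = ones(2)
Δt = 0.1
tspan = (0.0,1.0)
loss(dydtₚ) = sum(abs2, solution - prediction(f_t, dydtₚ))
res = DiffEqFlux.sciml_train(loss, dydtₚ, LBFGS())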
# Training on a batch of 10 signals: sum the loss over every (signal, target) pair
solution = [ones(2,11) for i in 1:10]
signals = [(t)->-i*t for i in 1:10]
function loss_batch(θ)
    n_exam = length(signals)
    loss_val = 0
    for i in 1:n_exam
        guess = prediction(signals[i], θ)
        loss_val += sum(abs2, solution[i] - guess)
    end
    loss_val
end
res = DiffEqFlux.sciml_train(loss_batch, dydtₚ, LBFGS())
Let me know if you need anything else. Cheers.
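To see the closure pattern in isolation, here is a minimal sketch in plain Julia with no packages. It rests on several assumptions: a hand-written fixed-step forward-Euler loop (hypothetical `euler_solve`) stands in for `ODEProblem` plus `Tsit5`, and a linear map (hypothetical `model`) stands in for the `FastChain`. It shows the anonymous function turning a four-argument right-hand side into the `(u, p, t)` form the solver calls, and it computes the batch loss both with an explicit loop and as a single `sum`.

```julia
# Hypothetical stand-in for the neural network: a linear map from [u; f(t)] to du/dt.
# θ is a flat vector of 6 parameters reshaped into a 2x3 weight matrix (column-major).
model(x, θ) = reshape(θ, 2, 3) * x

# Right-hand side that needs the external signal f as an extra argument.
node(u, θ, t, f) = model([u; f(t)], θ)

# Hypothetical fixed-step forward-Euler integrator (assumption: replaces ODEProblem + Tsit5).
# It only ever calls rhs(u, p, t), so the signal has to reach node through a closure.
function euler_solve(rhs, u0, tspan, p, Δt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / Δt)
    us = zeros(length(u0), nsteps + 1)   # one column per saved time point
    us[:, 1] = u0
    u = copy(u0)
    for k in 1:nsteps
        t = t0 + (k - 1) * Δt
        u = u + Δt * rhs(u, p, t)
        us[:, k + 1] = u
    end
    return us
end

# Closure: the anonymous function captures f and exposes the (u, p, t) signature.
prediction(f, θ; u0 = ones(2), tspan = (0.0, 1.0), Δt = 0.1) =
    euler_solve((u, p, t) -> node(u, p, t, f), u0, tspan, θ, Δt)

# A batch of signals and matching targets, shaped like the example above.
signals = [t -> -i * t for i in 1:10]
targets = [ones(2, 11) for _ in 1:10]

# Batch loss written as a loop over the examples...
function loss_batch(θ)
    loss_val = 0.0
    for i in eachindex(signals)
        loss_val += sum(abs2, targets[i] - prediction(signals[i], θ))
    end
    return loss_val
end

# ...and the same loss written without an explicit accumulator.
loss_batch_sum(θ) = sum(sum(abs2, y - prediction(f, θ)) for (f, y) in zip(signals, targets))
```

For example, with `θ = zeros(6)` the model's derivative is zero, so every prediction stays at `ones(2, 11)` and both losses are `0.0`. The loop and the `sum` form compute the same value up to floating-point rounding, so the choice between them is a matter of style.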
Hi, I tried to implement the closure in my code and also ran your example, but I always get this error:
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
iterate(!Matched::Core.SimpleVector) at essentials.jl:603
iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
iterate(!Matched::ExponentialBackOff) at error.jl:253
Actually, with the closure, training can't be performed even with just one signal.
Are you saying that the example code didn't work for you? Or are you talking about some other code?
Both. I tried to implement the closure in my other code and it didn't work; then I ran your example verbatim and it didn't work either.
I always get the same message:
LoadError: MethodError: no method matching iterate(::Val{1})
Closest candidates are:
iterate(!Matched::Core.SimpleVector) at essentials.jl:603
iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
iterate(!Matched::ExponentialBackOff) at error.jl:25
Can you show the output of ]st and ]st -m? I was testing that on Julia v1.4.1.
I'm using JuliaPro 1.4.0-1.
]st
Status `C:\Users\asus\.juliapro\JuliaPro_v1.4.0-1\environments\v1.4\Project.toml`
[c52e3926] Atom v0.12.10 ⚲
[aae7a2af] DiffEqFlux v1.8.1
[41bf760c] DiffEqSensitivity v6.10.2
[0c46a032] DifferentialEquations v6.12.0
[587475ba] Flux v0.10.3
[7073ff75] IJulia v1.20.0
[e5e0dc1b] Juno v0.8.1 ⚲
[429524aa] Optim v0.20.6
[1dea7af3] OrdinaryDiffEq v5.32.1
[4722fa14] PkgAuthentication v0.1.2
[91a5bcdd] Plots v0.29.9
[d330b81b] PyPlot v2.8.2
[e88e6eb3] Zygote v0.4.8
]st -m
Status `C:\Users\asus\.juliapro\JuliaPro_v1.4.0-1\environments\v1.4\Manifest.toml`
[621f4979] AbstractFFTs v0.5.0
[1520ce14] AbstractTrees v0.3.2
[79e6a3ab] Adapt v1.0.1
[ec485272] ArnoldiMethod v0.0.4
[7d9fca2a] Arpack v0.4.0
[68821587] Arpack_jll v3.5.0+3
[4fba245c] ArrayInterface v2.6.2
[4c555306] ArrayLayouts v0.2.1
[bf4720bc] AssetRegistry v0.1.0
[c52e3926] Atom v0.12.10 ⚲
[aae01518] BandedMatrices v0.15.1
[b99e7846] BinaryProvider v0.5.91
[a134a8b2] BlackBoxOptim v0.5.0
[764a87c0] BoundaryValueDiffEq v2.3.0
[6e34b625] Bzip2_jll v1.0.6+2
[fa961155] CEnum v0.2.0
[a9c8d775] CPUTime v1.0.0
[00ebfdb7] CSTParser v2.2.0
[3895d2a7] CUDAapi v3.1.0
[c5f51814] CUDAdrv v6.0.0
[be33ccc6] CUDAnative v2.10.2
[a603d957] CanonicalTraits v0.2.1
[7057c7e9] Cassette v0.3.1
[d360d2e6] ChainRulesCore v0.7.1
[53a63b46] CodeTools v0.7.0
[da1fd8a2] CodeTracking v0.5.8
[944b1d66] CodecZlib v0.6.0
[3da002f7] ColorTypes v0.9.1
[5ae59095] Colors v0.11.2
[bbf7d656] CommonSubexpressions v0.2.0
[34da2185] Compat v3.6.0
[e66e0078] CompilerSupportLibraries_jll v0.3.3+0
[8f4d0f93] Conda v1.4.1
[88cd18e8] ConsoleProgressMonitor v0.1.2
[d38c429a] Contour v0.5.2
[3a865a2d] CuArrays v1.7.3
[9a962f9c] DataAPI v1.1.0
[864edb3b] DataStructures v0.17.10
[bcd4f6db] DelayDiffEq v5.23.0
[2b5f629d] DiffEqBase v6.25.1
[459566f4] DiffEqCallbacks v2.12.1
[01453d9d] DiffEqDiffTools v1.7.0
[5a0ffddc] DiffEqFinancial v2.2.1
[aae7a2af] DiffEqFlux v1.8.1
[c894b116] DiffEqJump v6.5.0
[77a26b50] DiffEqNoiseProcess v3.9.0
[055956cb] DiffEqPhysics v3.5.0
[41bf760c] DiffEqSensitivity v6.10.2
[163ba53b] DiffResults v1.0.2
[b552c78f] DiffRules v1.0.1
[0c46a032] DifferentialEquations v6.12.0
[c619ae07] DimensionalPlotRecipes v1.1.0
[b4f34e82] Distances v0.8.2
[31c24e10] Distributions v0.23.2
[33d173f1] DocSeeker v0.4.1
[ffbed154] DocStringExtensions v0.8.1
[d4d017d3] ExponentialUtilities v1.6.0
[c87230d0] FFMPEG v0.3.0
[b22a6f82] FFMPEG_jll v4.1.0+2
[7a1cc6ca] FFTW v1.2.0
[f5851436] FFTW_jll v3.3.9+5
[5789e2e9] FileIO v1.2.2
[1a297f60] FillArrays v0.8.6
[6a86dc24] FiniteDiff v2.3.0
[53c48c17] FixedPointNumbers v0.7.1
[08572546] FlameGraphs v0.2.0
[587475ba] Flux v0.10.3
[59287772] Formatting v0.4.1
[f6369f11] ForwardDiff v0.10.10
[d7e528f0] FreeType2_jll v2.10.1+2
[559328eb] FriBidi_jll v1.0.5+2
[069b7b12] FunctionWrappers v1.1.1
[de31a74c] FunctionalCollections v0.5.0
[0c68f7d7] GPUArrays v2.0.1
[28b8d3ca] GR v0.46.90
[6b9d7cbe] GeneralizedGenerated v0.2.2
[01680d73] GenericSVD v0.3.0
[4d00f742] GeometryTypes v0.8.2
[cd3eb016] HTTP v0.8.12
[9fb69e20] Hiccup v0.2.2
[7073ff75] IJulia v1.20.0
[7869d1d1] IRTools v0.3.1
[9b13fd28] IndirectArrays v0.5.1
[d25df0c9] Inflate v0.1.2
[83e8ac13] IniFile v0.5.0
[1d5cc7b8] IntelOpenMP_jll v2018.0.3+0
[42fd0dbc] IterativeSolvers v0.8.3
[82899510] IteratorInterfaceExtensions v1.0.0
[682c06a0] JSON v0.21.0
[98e50ef6] JuliaFormatter v0.3.6
[aa1ae85d] JuliaInterpreter v0.7.12
[b14d175d] JuliaVariables v0.2.0
[e5e0dc1b] Juno v0.8.1 ⚲
[c1c5ebd0] LAME_jll v3.100.0+0
[929cbde3] LLVM v1.3.4
[7c4cb9fa] LNR v0.2.1
[b964fa9f] LaTeXStrings v1.1.0
[23fbe1c1] Latexify v0.13.0
[a5e1c1ea] LatinHypercubeSampling v1.2.0
[50d2b5c4] Lazy v0.15.0
[1d6d02ad] LeftChildRightSiblingTrees v0.1.2
[dd192d2f] LibVPX_jll v1.8.1+1
[093fc24a] LightGraphs v1.3.1
[d3d80556] LineSearches v7.0.1
[e6f89c97] LoggingExtras v0.4.0
[856f044c] MKL_jll v2019.0.117+2
[d8e11817] MLStyle v0.3.1
[1914dd2f] MacroTools v0.5.4
[739be429] MbedTLS v1.0.0
[c8ffd9c3] MbedTLS_jll v2.16.0+1
[442fdcdd] Measures v0.3.1
[e89f7d12] Media v0.5.0
[e1d29d7a] Missings v0.4.3
[961ee093] ModelingToolkit v1.4.2
[46d2c3a1] MuladdMacro v0.2.2
[f9640e96] MultiScaleArrays v1.6.0
[d41bc354] NLSolversBase v7.6.1
[2774e3e8] NLsolve v4.3.0
[872c559c] NNlib v0.6.6
[77ba4419] NaNMath v0.3.3
[71a1bf82] NameResolution v0.1.3
[510215fc] Observables v0.3.1
[e7412a2a] Ogg_jll v1.3.3+0
[4536629a] OpenBLAS_jll v0.3.9+0
[458c3c95] OpenSSL_jll v1.1.1+2
[efe28fd5] OpenSpecFun_jll v0.5.3+3
[429524aa] Optim v0.20.6
[91d4177d] Opus_jll v1.3.1+0
[bac558e1] OrderedCollections v1.1.0
[1dea7af3] OrdinaryDiffEq v5.32.1
[90014a1f] PDMats v0.9.12
[65888b18] ParameterizedFunctions v5.0.3
[d96e819e] Parameters v0.12.0
[69de0a69] Parsers v0.3.12
[fa939f87] Pidfile v1.1.0
[4722fa14] PkgAuthentication v0.1.2
[ccf2f8ad] PlotThemes v1.0.2
[995b91a9] PlotUtils v0.6.5
[91a5bcdd] Plots v0.29.9
[e409e4f3] PoissonRandom v0.4.0
[85a6dd25] PositiveFactorizations v0.2.3
[8162dcfd] PrettyPrint v0.1.0
[33c8b6b6] ProgressLogging v0.1.2
[92933f4c] ProgressMeter v1.2.0
[438e738f] PyCall v1.91.4
[d330b81b] PyPlot v2.8.2
[1fd47b50] QuadGK v2.3.1
[8a4e6c94] QuasiMonteCarlo v0.1.2
[e6cf234a] RandomNumbers v1.4.0
[3cdcf5f2] RecipesBase v0.8.0
[731186ca] RecursiveArrayTools v2.1.2
[f2c3362d] RecursiveFactorization v0.1.0
[189a3867] Reexport v0.2.0
[ae029012] Requires v1.0.1
[ae5879a3] ResettableStacks v1.0.0
[37e2e3b7] ReverseDiff v1.1.0
[79098fc4] Rmath v0.6.1
[f50d1b31] Rmath_jll v0.2.2+0
[f2b01f46] Roots v1.0.1
[992d4aef] Showoff v0.3.1
[699a6c99] SimpleTraits v0.9.1
[ed01d8cd] Sobol v1.3.0
[b85f4697] SoftGlobalScope v1.0.10
[a2af1166] SortingAlgorithms v0.3.1
[47a9eef4] SparseDiffTools v1.4.0
[d4ead438] SpatialIndexing v0.1.2
[276daf66] SpecialFunctions v0.10.0
[90137ffa] StaticArrays v0.12.1
[2913bbd2] StatsBase v0.33.0
[4c63d2b9] StatsFuns v0.9.4
[9672c7b4] SteadyStateDiffEq v1.5.0
[789caeaf] StochasticDiffEq v6.19.1
[88034a9c] StringDistances v0.6.3
[c3572dad] Sundials v3.9.0
[3783bdb8] TableTraits v1.0.0
[5d786b92] TerminalLoggers v0.1.1
[a759f4b9] TimerOutputs v0.5.3
[0796e94c] Tokenize v0.5.7
[37b6cedf] Traceur v0.3.0
[9f7883ad] Tracker v0.2.6
[3bb67fe8] TranscodingStreams v0.9.5
[a2a6695c] TreeViews v0.3.0
[30578b45] URIParser v0.4.0
[3a884ed6] UnPack v0.1.0
[81def892] VersionParsing v1.2.0
[19fa3120] VertexSafeGraphs v0.1.1
[0f1e0344] WebIO v0.8.92
[104b5d7c] WebSockets v1.5.2
[cc8bc4a8] Widgets v0.6.2
[c2297ded] ZMQ v1.2.0
[8f1865be] ZeroMQ_jll v4.3.2+1
[a5390f91] ZipFile v0.9.1
[83775a58] Zlib_jll v1.2.11+9
[e88e6eb3] Zygote v0.4.8
[700de1a5] ZygoteRules v0.2.0
[0ac62f75] libass_jll v0.14.0+0
[f638f0a6] libfdk_aac_jll v0.1.6+1
[f27f6e37] libvorbis_jll v1.3.6+2
[1270edf5] x264_jll v2019.5.25+1
[dfaa095f] x265_jll v3.0.0+0
[2a0f44e3] Base64
[ade2ca70] Dates
[8bb1440f] DelimitedFiles
[8ba89e20] Distributed
[7b1f6079] FileWatching
[b77e0a4c] InteractiveUtils
[76f85450] LibGit2
[8f399da3] Libdl
[37e2e46d] LinearAlgebra
[56ddb016] Logging
[d6f4376e] Markdown
[a63ad114] Mmap
[44cfe95a] Pkg
[de0858da] Printf
[9abbd945] Profile
[3fa0cd96] REPL
[9a3f8284] Random
[ea8e919c] SHA
[9e88b42a] Serialization
[1a1011a3] SharedArrays
[6462fe0b] Sockets
[2f01184e] SparseArrays
[10745b16] Statistics
[4607b0f0] SuiteSparse
[8dfed614] Test
[cf7118a7] UUIDs
[4ec0a83e] Unicode
Your Zygote version (v0.4.8) is quite far behind. Try updating it and see if that fixes the error.
Yeah, it works with the current version of Zygote. Thanks!!
One final question: since the signal varies in time, it must be evaluated at every step. Does that happen with the closure?
Yes. Since you're enclosing the function, the closure calls it with the current t every time the solver evaluates the right-hand side.
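A quick way to convince yourself is to make the signal record the times it is called at. The sketch below assumes a scalar toy ODE and a hypothetical fixed-step forward-Euler loop (`euler`) in place of the real solver; the names `calls`, `signal` and `rhs_with_signal` are illustrative only. With fixed steps the signal is evaluated once per step; a Runge-Kutta method such as Tsit5 evaluates the right-hand side at several stage times per step, so the signal would be evaluated at each of those times as well.

```julia
# Hypothetical signal that logs every time t it is evaluated at.
calls = Float64[]
signal(t) = (push!(calls, t); sin(t))

# Right-hand side that needs the signal; the closure below hides the extra argument.
rhs_with_signal(u, p, t, f) = -p * u + f(t)
rhs = (u, p, t) -> rhs_with_signal(u, p, t, signal)

# Hypothetical fixed-step forward-Euler loop (assumption: stands in for the real solver).
function euler(rhs, u0, p, tspan, Δt)
    nsteps = round(Int, (tspan[2] - tspan[1]) / Δt)
    u = u0
    for k in 0:nsteps-1
        t = tspan[1] + k * Δt
        u += Δt * rhs(u, p, t)   # each step calls the closure, which calls signal(t)
    end
    return u
end

u_end = euler(rhs, 1.0, 0.5, (0.0, 1.0), 0.1)
```

After the run, `calls` holds the ten step times 0.0, 0.1, ..., 0.9, one per right-hand-side evaluation. Empty it with `empty!(calls)` before solving again.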
Ok thanks!
I'm trying to use the DiffEqFlux package to define a machine learning model for an ODE. The model I need is a Neural ODE whose inputs are the initial condition and a continuous-time signal.
Currently, I have been able to train the NN with just one example using:
That works perfectly, but it uses just one signal for training. When I try to train on a batch of functions (using several signals) with:
I get this message:
I think the bug is caused by the loop, but I'm not sure.