SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
872 stars, 157 forks

BBO breaks #199

Closed ranjanan closed 4 years ago

ranjanan commented 4 years ago
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt2 = FastChain((x,p) -> x.^3,
                  FastDense(2,50,tanh),
                  FastDense(50,2))
n_ode = NeuralODE(dudt2,tspan,Tsit5(),saveat=t)

function predict_n_ode(p)
    n_ode(eltype(p).(u0),p)
end

function loss_n_ode(p)
    pred = predict_n_ode(p)
    loss = sum(abs2,ode_data .- pred)
    loss,pred
end

loss_n_ode(n_ode.p) # n_ode.p stores the initial parameters of the neural ODE

cb = function (p,l,pred;doplot=false) # callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pl = scatter(t,ode_data[1,:],label="data")
        scatter!(pl,t,pred[1,:],label="prediction")
        display(plot(pl))
    end
    return false
end

# Display the ODE with the initial parameter values.
cb(n_ode.p,loss_n_ode(n_ode.p)...)
res1 = DiffEqFlux.sciml_train(loss_n_ode, n_ode.p, DiffEqFlux.BBO(), cb = cb, maxiters = 300)

Error:

julia> include("test.jl")
351.97626f0
351.97623f0
Training 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:16
ERROR: LoadError: MethodError: no method matching apply!(::DiffEqFlux.BBO, ::Array{Float32,1}, ::Array{Float32,1})
Closest candidates are:
  apply!(::Descent, ::Any, ::Any) at /home/ranjan/.julia/packages/Flux/NpkMm/src/optimise/optimisers.jl:39
  apply!(::Momentum, ::Any, ::Any) at /home/ranjan/.julia/packages/Flux/NpkMm/src/optimise/optimisers.jl:67
  apply!(::Nesterov, ::Any, ::Any) at /home/ranjan/.julia/packages/Flux/NpkMm/src/optimise/optimisers.jl:98
  ...
Stacktrace:
 [1] update!(::DiffEqFlux.BBO, ::Array{Float32,1}, ::Array{Float32,1}) at /home/ranjan/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:19
 [2] update!(::DiffEqFlux.BBO, ::Zygote.Params, ::Zygote.Grads) at /home/ranjan/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:29
 [3] (::DiffEqFlux.var"#23#28"{var"#81#83",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params})() at /home/ranjan/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:110
 [4] with_logstate(::DiffEqFlux.var"#23#28"{var"#81#83",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params}, ::Base.CoreLogging.LogState) at ./logging.jl:395
 [5] with_logger at ./logging.jl:491 [inlined]
 [6] maybe_with_logger(::DiffEqFlux.var"#23#28"{var"#81#83",Int64,Bool,Bool,typeof(loss_n_ode),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqBase.var"#10#12"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqBase.var"#11#13"}}}) at /home/ranjan/.julia/packages/DiffEqBase/k3AhB/src/utils.jl:259
 [7] #sciml_train#22(::Function, ::Int64, ::Bool, ::Bool, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::DiffEqFlux.BBO, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}) at /home/ranjan/.julia/packages/DiffEqFlux/YKKwl/src/train.jl:42
 [8] #sciml_train at ./none:0 [inlined] (repeats 2 times)
 [9] top-level scope at /home/ranjan/.julia/dev/ARPAEMERL/test/test.jl:46
 [10] include at ./boot.jl:328 [inlined]
 [11] include_relative(::Module, ::String) at ./loading.jl:1105
 [12] include(::Module, ::String) at ./Base.jl:31
 [13] include(::String) at ./client.jl:424
 [14] top-level scope at REPL[10]:1
in expression starting at /home/ranjan/.julia/dev/ARPAEMERL/test/test.jl:46
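The MethodError above is a multiple-dispatch failure: judging from frames [1] and [2], passing n_ode.p as the second argument sent BBO() down the gradient-based training path, which calls update!/apply!, and no apply! method exists for BBO. A minimal sketch of that failure mode follows; the types MyDescent and MyBBO and the function train are hypothetical stand-ins, not the real Flux or DiffEqFlux definitions.

```julia
# Hypothetical stand-ins; NOT the real Flux/DiffEqFlux types.
abstract type GradientOpt end
struct MyDescent <: GradientOpt
    eta::Float64
end
struct MyBBO end  # marker type for a derivative-free optimizer

# Step rule: only defined for gradient-based optimizers.
apply!(o::MyDescent, x, Δ) = Δ .* o.eta

# Generic three-argument path: initial θ plus gradient updates.
function train(grad, θ, opt; maxiters = 100)
    θ = copy(θ)
    for _ in 1:maxiters
        θ .-= apply!(opt, θ, grad(θ))  # MethodError if opt has no apply! method
    end
    return θ
end

# Separate two-argument path for the derivative-free optimizer:
# no initial θ, bounds required. Placeholder body: a real
# implementation would search the box instead of using its midpoint.
function train(f, opt::MyBBO; lower_bounds, upper_bounds)
    return f((lower_bounds .+ upper_bounds) ./ 2)
end
```

Calling train(grad, θ, MyBBO()) matches the three-argument method, not the MyBBO one, so it fails inside apply! exactly as in the trace.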
ranjanan commented 4 years ago

This is on DiffEqFlux 1.8.0

ranjanan commented 4 years ago

I'm working on a PR for this

abhigupta768 commented 4 years ago

Hi @ranjanan,

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt2 = FastChain((x,p) -> x.^3,
                  FastDense(2,50,tanh),
                  FastDense(50,2))
n_ode = NeuralODE(dudt2,tspan,Tsit5(),saveat=t)

function predict_n_ode(p)
    n_ode(eltype(p).(u0),p)
end

function loss_n_ode(p)
    pred = predict_n_ode(p)
    loss = sum(abs2,ode_data .- pred)
    loss,pred
end

loss_n_ode(n_ode.p) # n_ode.p stores the initial parameters of the neural ODE

# BBO takes no initial parameters and no callback, but both bounds are required.
# 252 == length(n_ode.p): (2*50 + 50) + (50*2 + 2) parameters.
res1 = DiffEqFlux.sciml_train(loss_n_ode, DiffEqFlux.BBO(), maxiters = 300,
                              lower_bounds = [-1000.0 for i in 1:252],
                              upper_bounds = [1000.0 for i in 1:252])

This works. In your script you are passing an initial parameter, which is not needed for BBO; callbacks are not supported with BBO; and the bounds arguments are missing.
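The 252 hard-coded in the call above is the parameter count of dudt2. A small sketch of where it comes from, assuming each FastDense layer stores an out×in weight matrix plus an out-length bias (which matches the 252 used above); in a real script, length(n_ode.p) gives the same number directly. The helper dense_nparams is hypothetical.

```julia
# Hypothetical helper: parameters of a dense layer = weights (nout × nin) + bias (nout).
dense_nparams(nin, nout) = nin * nout + nout

# dudt2 = FastChain(x -> x.^3, FastDense(2,50,tanh), FastDense(50,2));
# the leading x.^3 closure has no parameters.
nparams = dense_nparams(2, 50) + dense_nparams(50, 2)  # 150 + 102

# Both bounds: same length as the parameter vector, same element type (Float64).
lower_bounds = fill(-1000.0, nparams)
upper_bounds = fill(1000.0, nparams)
```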

ChrisRackauckas commented 4 years ago

Can we figure out a way to make callbacks work with BBO? I thought there was a mechanism. Otherwise we should probably make some upstream changes.
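One possible approach, sketched here as an assumption rather than what DiffEqFlux or BlackBoxOptim actually do: wrap the user's loss so the optimizer sees a scalar objective while the callback still runs with the cb(p, l, pred) convention used in the script. The names with_callback and StopTraining are hypothetical. One caveat of this design: a black-box optimizer evaluates the objective many times per iteration, so the callback fires per evaluation, not per iteration.

```julia
# Hypothetical helpers, not part of DiffEqFlux or BlackBoxOptim.
struct StopTraining <: Exception end

# Wrap `loss` (which returns (l, extras...), like loss_n_ode) so that a
# derivative-free optimizer gets a scalar objective and `cb` still runs
# on every evaluation.
function with_callback(loss, cb)
    return function (p)
        out = loss(p)   # e.g. (loss, pred)
        l = first(out)
        # A callback returning true requests a halt; the optimizer driver
        # would need to catch StopTraining.
        cb(p, out...) && throw(StopTraining())
        return l
    end
end
```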

ranjanan commented 4 years ago

I was calling this wrong; I changed it to this:

res1 = DiffEqFlux.sciml_train(loss_n_ode, DiffEqFlux.BBO(),
                              lower_bounds = [-1 for _ in 1:length(n_ode.p)],
                              upper_bounds = [1. for _ in 1:length(n_ode.p)],
                              cb = cb, maxiters = 300)

and I get:

julia> include("test.jl")
321.23785f0
ERROR: LoadError: ArgumentError: Using Array{Tuple{Int64,Float64},1} for SearchRange is not supported.
Stacktrace:
 [1] check_and_create_search_space(::DictChain{Symbol,Any}) at /home/ranjan/.julia/packages/BlackBoxOptim/ZdVko/src/default_parameters.jl:71
 [2] setup_problem(::Function, ::DictChain{Symbol,Any}) at /home/ranjan/.julia/packages/BlackBoxOptim/ZdVko/src/bboptimize.jl:27
 [3] #bbsetup#86(::Base.Iterators.Pairs{Symbol,Any,NTuple{4,Symbol},NamedTuple{(:Method, :SearchRange, :MaxSteps, :cb),Tuple{Symbol,Array{Tuple{Int64,Float64},1},Int64,var"#82#84"}}}, ::typeof(bbsetup), ::Function, ::Dict{Symbol,Any}) at /home/ranjan/.julia/packages/BlackBoxOptim/ZdVko/src/bboptimize.jl:87
 [4] #bbsetup at ./none:0 [inlined]
 [5] #bboptimize#85(::Base.Iterators.Pairs{Symbol,Any,NTuple{4,Symbol},NamedTuple{(:Method, :SearchRange, :MaxSteps, :cb),Tuple{Symbol,Array{Tuple{Int64,Float64},1},Int64,var"#82#84"}}}, ::typeof(bboptimize), ::Function, ::Dict{Symbol,Any}) at /home/ranjan/.julia/packages/BlackBoxOptim/ZdVko/src/bboptimize.jl:70
 [6] #bboptimize at ./none:0 [inlined] (repeats 2 times)
 [7] #sciml_train#216(::Array{Int64,1}, ::Array{Float64,1}, ::Int64, ::Base.Iterators.Pairs{Symbol,var"#82#84",Tuple{Symbol},NamedTuple{(:cb,),Tuple{var"#82#84"}}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::DiffEqFlux.BBO, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}) at /home/ranjan/.julia/dev/DiffEqFlux/src/train.jl:328
 [8] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:lower_bounds, :upper_bounds, :cb, :maxiters),Tuple{Array{Int64,1},Array{Float64,1},var"#82#84",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::DiffEqFlux.BBO) at none:0
 [9] top-level scope at /home/ranjan/.julia/dev/ARPAEMERL/test/test.jl:46
 [10] include at ./boot.jl:328 [inlined]
 [11] include_relative(::Module, ::String) at ./loading.jl:1105
 [12] include(::Module, ::String) at ./Base.jl:31
 [13] include(::String) at ./client.jl:424
 [14] top-level scope at REPL[12]:1
in expression starting at /home/ranjan/.julia/dev/ARPAEMERL/test/test.jl:46
ChrisRackauckas commented 4 years ago

It doesn't make sense to bound the parameters of a neural network to have to be less than 1. I'm sure that's not what you meant?

abhigupta768 commented 4 years ago

The SearchRange in the error is an Array{Tuple{Int64,Float64},1}; you'll need to pass the upper and lower bounds with the same element type.
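Why the mixed type appears: -1 is an integer literal and 1. is a float literal, so the two comprehensions in the call above produce different element types. Judging by the SearchRange type in the error, the bounds are paired elementwise before reaching BlackBoxOptim; the sketch below mimics that pairing with zip (an assumption about the mechanism, not DiffEqFlux's code).

```julia
n = 3
lower = [-1 for _ in 1:n]   # Vector{Int}: -1 is an integer literal
upper = [1. for _ in 1:n]   # Vector{Float64}: 1. is a float literal

# Elementwise pairing gives Tuple{Int,Float64}, the kind of type rejected in the error.
ranges = collect(zip(lower, upper))

# Fix: make both bounds Float64, e.g. write -1. or convert with float.
ranges_ok = collect(zip(float.(lower), upper))
```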

ranjanan commented 4 years ago

No, this is my bad again; I got it to work. However, the API is different from the other DiffEqFlux functions. Perhaps it should be more like:

sciml_train(loss, _theta, BBO(;kwargs); kwargs)

Instead, it's now

sciml_train(loss, BBO(); lower_bounds, upper_bounds)
ranjanan commented 4 years ago

At a minimum, this needs an error message.
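A sketch of what an up-front check could look like, so a bad call fails with a readable message instead of deep inside BlackBoxOptim. The function check_bounds is hypothetical, not part of DiffEqFlux; it rejects mismatched lengths and inverted bounds, and it promotes mixed Int/Float inputs to one floating-point type rather than erroring (a design choice; raising an ArgumentError on mixed types would also work).

```julia
# Hypothetical validation helper, not part of DiffEqFlux.
function check_bounds(lower_bounds::AbstractVector{<:Real},
                      upper_bounds::AbstractVector{<:Real})
    length(lower_bounds) == length(upper_bounds) ||
        throw(ArgumentError("lower_bounds has length $(length(lower_bounds)) " *
                            "but upper_bounds has length $(length(upper_bounds))"))
    all(lower_bounds .<= upper_bounds) ||
        throw(ArgumentError("every lower bound must be <= its upper bound"))
    # Promote both sides to one floating-point type, e.g. Int and Float64 -> Float64.
    T = float(promote_type(eltype(lower_bounds), eltype(upper_bounds)))
    return [(T(l), T(u)) for (l, u) in zip(lower_bounds, upper_bounds)]
end
```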

ChrisRackauckas commented 4 years ago

Yes, it should error if you don't pass both bounds. The kwargs should have no default and that would make Julia throw an error message, so something needs to change in our implementation.
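In Julia, a keyword argument declared without a default is required, and omitting it throws UndefKeywordError before the function body runs. A minimal sketch with a hypothetical train_box (not sciml_train) shows both sides of the discussion: a missing bound errors immediately, while bounds that are present but of mixed element type pass this check, consistent with the mismatch surfacing later in BlackBoxOptim.

```julia
# Hypothetical stand-in, not DiffEqFlux's sciml_train.
# lower_bounds and upper_bounds have no defaults, so they are required.
function train_box(loss; lower_bounds, upper_bounds, maxiters = 100)
    return (length(lower_bounds), length(upper_bounds), maxiters)
end

# Omitting a required keyword: UndefKeywordError is caught and returned here.
missing_kw_error = try
    train_box(sum; lower_bounds = [-1.0])
    nothing
catch e
    e
end

# Both keywords present: accepted even with mixed Int/Float64 element types,
# because nothing in the signature checks them.
ok = train_box(sum; lower_bounds = [-1, -1], upper_bounds = [1.0, 1.0])
```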

ChrisRackauckas commented 4 years ago

> perhaps it should be more like:

It should be more like that; right now it's just version 1.

We need to really clean up and document how we handle box constraints.

abhigupta768 commented 4 years ago

I agree that having two different interfaces is a bit awkward, but this is because BlackBoxOptim doesn't take an initial parameter value.
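To illustrate why a box-constrained black-box method needs bounds but no starting point, here is a deliberately simple uniform random search. It is a swapped-in technique for illustration only, not BlackBoxOptim's algorithm; the function random_search is hypothetical. Every candidate is drawn from the box, so the bounds fully define the search and an initial guess has no role.

```julia
using Random

# Uniform random search over the box [lower, upper]: needs bounds, no initial point.
function random_search(f, lower, upper; iters = 10_000, rng = MersenneTwister(0))
    sample() = lower .+ rand(rng, length(lower)) .* (upper .- lower)
    best_x = sample()
    best_f = f(best_x)
    for _ in 2:iters
        x = sample()
        fx = f(x)
        if fx < best_f
            best_x, best_f = x, fx
        end
    end
    return best_x, best_f
end
```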

> The kwargs should have no default and that would make Julia throw an error message,

@ChrisRackauckas, it is implemented that way.

ranjanan commented 4 years ago

PR incoming.

ChrisRackauckas commented 4 years ago

Interesting. Then it should throw a "lower_bounds not passed" kind of error message?

ranjanan commented 4 years ago

Yes, it does.

abhigupta768 commented 4 years ago

The bounds were passed; the issue was that they were not of the same type.

ranjanan commented 4 years ago

Yes, in my case I made the error above. Happy to bikeshed on the API in #200.