SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Trouble running the Neural ODE example on GPU #21

Closed: mbrookhart closed this issue 5 years ago

mbrookhart commented 5 years ago

First off, awesome paper, thank you!

When I run the Neural ODE example on the CPU, it works great. However, when I switch it over to the GPU using the approach described in the paper, i.e., using CuArrays and x -> neural_ode(gpu(dudt), gpu(x), tspan, BS3(), saveat=0.1), I see:
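For reference, a minimal sketch of the setup being run (the names `dudt`, `tspan`, `u0`, and the layer sizes follow the DiffEqFlux example and are assumptions about the exact script, not a copy of it):

```julia
# Hedged sketch of the 2019-era GPU Neural ODE setup described above,
# using the DiffEqFlux/Flux APIs of that time.
using DiffEqFlux, Flux, OrdinaryDiffEq, CuArrays

u0 = Float32[2.0; 0.0]          # assumed initial condition, as in the example
tspan = (0.0f0, 1.5f0)          # assumed time span

# Small network defining the ODE right-hand side; the anonymous x -> x.^3
# layer matches the ##3#4 closure visible in the stack trace below.
dudt = Chain(x -> x.^3,
             Dense(2, 50, tanh),
             Dense(50, 2))

# Move the model and the state to the GPU, then solve with BS3.
n_ode = x -> neural_ode(gpu(dudt), gpu(x), tspan, BS3(), saveat = 0.1)
pred = n_ode(u0)                # this call triggers the reshape MethodError
```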

ERROR: LoadError: MethodError: Base._reshape(::CuArray{Float32,1}, ::Tuple{Int64}) is ambiguous. Candidates:
  _reshape(A::GPUArrays.GPUArray{T,N} where N, dims::Tuple{Vararg{Int64,N}} where N) where T in GPUArrays at /home/brookhart/.julia/packages/GPUArrays/t8tJB/src/abstractarray.jl:230
  _reshape(A::GPUArrays.GPUArray{T,1}, dims::Tuple{Integer}) where T in GPUArrays at /home/brookhart/.julia/packages/GPUArrays/t8tJB/src/abstractarray.jl:236
  _reshape(parent::CuArray, dims::Tuple{Vararg{Int64,N}} where N) in CuArrays at /home/brookhart/.julia/packages/CuArrays/PD3UJ/src/array.jl:106
  _reshape(v::AbstractArray{T,1} where T, dims::Tuple{Int64}) in Base at reshapedarray.jl:167
Possible fix, define
  _reshape(::CuArray{T,1}, ::Tuple{Int64})
Stacktrace:
 [1] reshape(::CuArray{Float32,1}, ::Tuple{Int64}) at ./reshapedarray.jl:112
 [2] (::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}})(::TrackedArray{…,CuArray{Float32,1}}) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/utils.jl:14
 [3] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::TrackedArray{…,CuArray{Float32,1}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [4] #mapleaves at ./none:0 [inlined]
 [5] #38 at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27 [inlined]
 [6] _broadcast_getindex_evalf(::getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}}, ::TrackedArray{…,CuArray{Float32,1}}) at ./broadcast.jl:578
 [7] _broadcast_getindex at ./broadcast.jl:551 [inlined]
 [8] (::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}},typeof(tanh)}}}})(::Int64) at ./broadcast.jl:953
 [9] ntuple(::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}},typeof(tanh)}}}}, ::Val{3}) at ./tuple.jl:161
 [10] copy at ./broadcast.jl:953 [inlined]
 [11] materialize at ./broadcast.jl:753 [inlined]
 [12] mapchildren(::Function, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:13
 [13] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [14] (::getfield(Flux, Symbol("#kw##mapleaves")))(::NamedTuple{(:cache,),Tuple{IdDict{Any,Any}}}, ::typeof(mapleaves), ::Function, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at ./none:0
 [15] #38 at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27 [inlined]
 [16] _broadcast_getindex_evalf(::getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}}, ::Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}) at ./broadcast.jl:578
 [17] _broadcast_getindex at ./broadcast.jl:551 [inlined]
 [18] (::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}}})(::Int64) at ./broadcast.jl:953
 [19] ntuple(::getfield(Base.Broadcast, Symbol("##19#20")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,getfield(Flux, Symbol("##38#39")){IdDict{Any,Any},getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}},Tuple{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}}}, ::Val{3}) at ./tuple.jl:161
 [20] copy at ./broadcast.jl:953 [inlined]
 [21] materialize at ./broadcast.jl:753 [inlined]
 [22] mapchildren(::Function, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/layers/basic.jl:28
 [23] #mapleaves#37(::IdDict{Any,Any}, ::Function, ::getfield(DiffEqFlux, Symbol("##33#34")){CuArray{Float32,1}}, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}) at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:27
 [24] mapleaves at /home/brookhart/.julia/packages/Flux/8XpDt/src/treelike.jl:26 [inlined]
 [25] restructure at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/utils.jl:12 [inlined]
 [26] (::getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}})(::CuArray{Float32,1}, ::CuArray{Float32,1}, ::CuArray{Float32,1}, ::Float32) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/neural_de.jl:6
 [27] ODEFunction at /home/brookhart/.julia/packages/DiffEqBase/PvfXM/src/diffeqfunction.jl:106 [inlined]
 [28] initialize!(::OrdinaryDiffEq.ODEIntegrator{BS3,true,CuArray{Float32,1},Float32,CuArray{Float32,1},Float32,Float32,Float32,Array{CuArray{Float32,1},1},ODESolution{Float32,2,Array{CuArray{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},BS3,OrdinaryDiffEq.InterpolationData{ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArray{Float32,1},1},Array{Float32,1},Array{Array{CuArray{Float32,1},1},1},OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}}}},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, 
Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}},OrdinaryDiffEq.DEOptions{Float32,Float32,Float32,Float32,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float32,DataStructures.LessThan},DataStructures.BinaryHeap{Float32,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float32,1},Float64,Array{Float32,1}},CuArray{Float32,1},Float32}, ::OrdinaryDiffEq.BS3Cache{CuArray{Float32,1},CuArray{Float32,1},CuArray{Float32,1},OrdinaryDiffEq.BS3ConstantCache{Float32,Float32}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/perform_step/low_order_rk_perform_step.jl:39
 [29] #__init#479(::Float64, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float32, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Int64, ::Rational{Int64}, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::Float32, ::Float32, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/solve.jl:312
 [30] (::getfield(DiffEqBase, Symbol("#kw##__init")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(DiffEqBase.__init), ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0
 [31] #__solve#478(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /home/brookhart/.julia/packages/OrdinaryDiffEq/6mZB8/src/solve.jl:6
 [32] #__solve at ./none:0 [inlined] (repeats 5 times)
 [33] #solve#425(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,CuArray{Float32,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /home/brookhart/.julia/packages/DiffEqBase/PvfXM/src/solve.jl:39
 [34] #solve at ./none:0 [inlined]
 [35] #_forward#17 at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/layers.jl:54 [inlined]
 [36] #_forward at ./none:0 [inlined]
 [37] #track#1 at /home/brookhart/.julia/packages/Flux/8XpDt/src/tracker/Tracker.jl:51 [inlined]
 [38] #track at ./none:0 [inlined]
 [39] #diffeq_adjoint#16 at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/layers.jl:50 [inlined]
 [40] (::getfield(DiffEqFlux, Symbol("#kw##diffeq_adjoint")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(diffeq_adjoint), ::TrackedArray{…,CuArray{Float32,1}}, ::ODEProblem{CuArray{Float32,1},Tuple{Float32,Float32},true,TrackedArray{…,CuArray{Float32,1}},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at ./none:0
 [41] #neural_ode#23(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}}, ::Function, ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}, ::CuArray{Float32,1}, ::Tuple{Float32,Float32}, ::BS3) at /home/brookhart/.julia/packages/DiffEqFlux/1w1tX/src/Flux/neural_de.jl:8
 [42] (::getfield(DiffEqFlux, Symbol("#kw##neural_ode")))(::NamedTuple{(:saveat,),Tuple{Float64}}, ::typeof(neural_ode), ::Chain{Tuple{getfield(Main, Symbol("##3#4")),Dense{typeof(tanh),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}},Dense{typeof(identity),TrackedArray{…,CuArray{Float32,2}},TrackedArray{…,CuArray{Float32,1}}}}}, ::CuArray{Float32,1}, ::Tuple{Float32,Float32}, ::BS3) at ./none:0
 [43] (::getfield(Main, Symbol("##5#6")))(::Array{Float32,1}) at /home/brookhart/Documents/Kaggle/titanic/kernel.jl:20
 [44] top-level scope at none:0
 [45] include at ./boot.jl:326 [inlined]
 [46] include_relative(::Module, ::String) at ./loading.jl:1038
 [47] include(::Module, ::String) at ./sysimg.jl:29
 [48] exec_options(::Base.JLOptions) at ./client.jl:267
 [49] _start() at ./client.jl:43

This matches the error reported in https://github.com/JuliaGPU/CuArrays.jl/issues/161

I'm using Julia 1.1, the latest releases of all packages, and CUDA 10 on a GTX 1070 Ti running Ubuntu 18.04.

Thanks!

ChrisRackauckas commented 5 years ago

Yeah, after the blog post people had matrices of parameters, so we added a reshape in there to handle the different sizes. It looks like CuArrays can't handle reshaping a vector to a vector, which is the upstream issue you're pointing to.
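For context, the ambiguity can be resolved upstream by defining the exact intersection signature the MethodError suggests. A hypothetical sketch (this is roughly the shape of the fix CuArrays needs to provide, not its actual source; the signature is taken from the "Possible fix" hint in the error above):

```julia
# Hypothetical disambiguating method: for a vector-to-vector reshape, the
# generic Base method already does the right thing (it checks the length
# and returns the vector), so we can delegate to it explicitly.
Base._reshape(A::CuArray{T,1}, dims::Tuple{Int64}) where {T} =
    invoke(Base._reshape, Tuple{AbstractVector, Tuple{Int64}}, A, dims)
```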

ChrisRackauckas commented 5 years ago

This should be fixed on the master branch of CuArrays.jl. Check it out and see if there are any issues.
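For anyone following along, checking out the master branch can be done with the Julia 1.x Pkg API (package name as in this thread):

```julia
# Switch CuArrays.jl to its development branch to pick up the fix.
using Pkg
Pkg.add(PackageSpec(name = "CuArrays", rev = "master"))
# Equivalently, from the Pkg REPL (press `]`):
#   add CuArrays#master
```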

mbrookhart commented 5 years ago

Thanks! I'll check it tonight.

mbrookhart commented 5 years ago

Thanks! I had to update CUDAdrv and CUDAnative as well, but that got me past the original error. Now I'm seeing a different error in the ADAM update in Flux, but I wonder if it's just a version mismatch somewhere.
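To pin down where the second failure happens: in the example the optimizer update runs inside `Flux.train!`. A hedged sketch of that step, assuming the Flux 0.8-era API and example-style names (`dudt`, `n_ode`, `u0`, `ode_data` are assumptions carried over from the example, not the exact script):

```julia
# Hedged sketch of the training step where the ADAM-update error appears.
using Flux

opt  = ADAM(0.1)                       # optimizer whose update step fails
ps   = Flux.params(dudt)               # tracked parameters of the model
data = Iterators.repeated((), 100)     # 100 optimization steps, no minibatches

# Assumed loss, following the DiffEqFlux example: squared error between
# the solver output and the target trajectory.
loss_n_ode() = sum(abs2, ode_data .- n_ode(u0))

# The reported failure occurs inside the optimizer update during train!.
Flux.train!(loss_n_ode, ps, data, opt)
```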

ChrisRackauckas commented 5 years ago

DiffEq had a breaking v6.0 release, and we're looking at this in the next few days to make sure it gets updated.