Closed cems2 closed 4 years ago
I think you can define an in-place problem function to make this work (rather, a workaround that avoids hitting this issue):
using Flux
using OrdinaryDiffEq
using DiffEqFlux
model1 = FastDense(2,2)
model2 = FastChain(model1)
#initial conditions
u0 = Float32[1.0,1.0]
u1 = Float32[1.5,1.5]
u2 = Float32[2.0,2.0]
# minibatch = [(u0[:,:],)] # note this is 2x1; the orientation matters for the error
minibatch = [([u0 u1 u2],)]
# test both models
function do_it()
    for model in (model1, model2)
        println("\nmodel ", model)
        p = initial_params(model)
        # set up the ODE problem
        # f(x,p,t) = model(x,p) # the error traceback (Juno highlights this in red)
        function f(dx, x, p, t)
            dx .= model(x, p)
            nothing
        end
        prob = ODEProblem(f, minibatch[1][1], (0.0f0, 1.0f0), p) # u0 is a placeholder
        pred(u, p) = concrete_solve(prob, Tsit5(), u, p, saveat = 0.01)
        # validate that prediction works
        pred(u0, p)
        pred(minibatch[1][1], p)
        # set up the loss function
        loss_vec_batch(p, mb) = sum(abs, pred(mb, p) .- mb) # vectorized over the batch
        loss_vec(p) = loss_vec_batch(p, minibatch[1]...)
        loss = loss_vec # rename
        println(loss(p))
        # validate that the loss function works
        loss_vec_batch(p, minibatch[1]...) == loss_vec(p)
        # set up the gradient parameters
        ps = Flux.params(p)
        # use Zygote on the vector case
        gsv = Flux.Zygote.gradient(ps) do
            x = loss(p)
            first(x)
        end
        println("\nMODEL RAN CLEANLY. ", gsv.grads)
    end
end
do_it()
do_it()
julia> do_it()
model FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2))
76.2751
MODEL RAN CLEANLY. IdDict{Any,Any}(Float32[0.7815552, -0.026053874, -0.60071194, -0.05056247, 0.0, 0.0] => Float32[325.44238, -300.28796, 296.0859, -272.6577, 202.24295, -186.32872])
model FastChain{Tuple{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}}}((FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}}(2, 2, identity, DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}(Flux.glorot_uniform, Flux.zeros, 2, 2)),))
251.10022
MODEL RAN CLEANLY. IdDict{Any,Any}(Float32[0.41792107, 0.47331968, 0.24224015, -0.25262898, 0.0, 0.0] => Float32[374.25388, 288.6827, 330.43677, 253.09106, 202.13618, 153.87003])
Three questions:
And finally, thank you for taking a look and finding this workaround.
You shouldn't have to transform this to in-place. I'll look into this today.
Out of curiosity, how does ODEProblem tell whether the function handed in takes 3 or 4 parameters (i.e., returns a value vs. writes the result in place)? Introspection? Testing return values? How?
We keep a sorcerer who specializes in dark magic on retainer for whenever someone defines a DEProblem. Though he recently digitized himself into method-table inspection code:
https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/utils.jl#L1-L43
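In other words, the arity can be read off the method table. A minimal sketch of the idea, assuming a hypothetical `numargs` helper (illustrative only, not DiffEqBase's actual implementation):

```julia
# Out-of-place form: 3 arguments, returns the derivative.
f_oop(u, p, t) = u
# In-place form: 4 arguments, writes into du.
f_iip(du, u, p, t) = (du .= u; nothing)

# Each entry of `methods(f)` carries `nargs`, which counts the function
# itself as the first argument, so subtract 1 to get the user-facing arity.
numargs(f) = maximum(m -> m.nargs - 1, methods(f))

numargs(f_oop)  # 3 → treated as out-of-place
numargs(f_iip)  # 4 → treated as in-place
```

With that arity in hand, the problem constructor can pick the in-place code path when it sees 4 arguments and the out-of-place path when it sees 3.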
This problem is the same as https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160
It is? They both involve Zygote.gradient, but in the case of #171 it seems like FastChain() is somehow an important factor.
On Mar 3, 2020, at 10:01 AM, Christopher Rackauckas notifications@github.com wrote:
This problem is the same as #160 (https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/160).
Because of how the output is handled.
I think that the error from this code is also caused by the same issue.
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, DiffEqSensitivity
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
# testing on MNIST
# Classify MNIST digits
imgs = MNIST.images()
# Stack images into one large batch
X = hcat(float.(reshape.(imgs, :))...)
X = convert(Array{Float32,2}, X)
labels = MNIST.labels()
# One-hot-encode the labels
Y = onehotbatch(labels, 0:9)
import NNlib.softmax
softmax(x::AbstractArray, p::AbstractArray) = softmax(x)
N = 32
model = FastChain(FastDense(28^2, N, tanh),
                  FastDense(N, N, tanh),
                  FastDense(N, 10),
                  softmax)
p0 = initial_params(model)
model(X[:, 1:100], p0)
dataset = repeated((X, Y), 5)
function loss(p, x, y)
    pred = model(x, p)
    l = crossentropy(pred, y)
    l, pred
end
for mb in dataset
    display(loss(p0, mb...)[1])
end
cb = function (p, l, pred)
    display(l)
    return false
end
res1 = DiffEqFlux.sciml_train(loss, p0, ADAM(), dataset, cb = cb, maxiters = 100)
Which throws the output:
ERROR: ArgumentError: number of columns of each array must match (got (1, 60000))
Stacktrace:
[1] _typed_vcat(::Type{Float32}, ::Tuple{Array{Float32,1},Array{Float32,2}}) at ./abstractarray.jl:1359
[2] typed_vcat at ./abstractarray.jl:1373 [inlined]
[3] vcat at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/SparseArrays/src/sparsevector.jl:1079 [inlined]
[4] (::DiffEqFlux.var"#FastDense_adjoint#49"{FastDense{typeof(identity),DiffEqFlux.var"#initial_params#48"{typeof(Flux.glorot_uniform),typeof(Flux.zeros),Int64,Int64}},Array{Float32,2},Array{Float32,2},Array{Float32,2},Array{Float32,2}})(::Array{Float32,2}) at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:53
[5] #167#back at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[6] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
[7] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[8] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
[9] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[10] applychain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:20 [inlined]
[11] (::typeof(∂(applychain)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[12] FastChain at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/fast_layers.jl:21 [inlined]
[13] (::typeof(∂(λ)))(::Array{Float32,2}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[14] loss at /Users/ger/Documents/Issues/for_posting_minibatch.jl:32 [inlined]
[15] (::typeof(∂(loss)))(::Tuple{Float32,Nothing}) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[16] #165 at /Users/ger/.julia/packages/Zygote/tJj2w/src/lib/lib.jl:156 [inlined]
[17] (::Zygote.var"#321#back#167"{Zygote.var"#165#166"{typeof(∂(loss)),Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}}}})(::Tuple{Float32,Nothing}) at /Users/ger/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[18] #18 at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:50 [inlined]
[19] (::typeof(∂(λ)))(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
[20] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
[21] gradient(::Function, ::Zygote.Params) at /Users/ger/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
[22] macro expansion at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:49 [inlined]
[23] macro expansion at /Users/ger/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
[24] #sciml_train#16(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}) at /Users/ger/.julia/packages/DiffEqFlux/dZmLq/src/train.jl:48
[25] (::DiffEqFlux.var"#kw##sciml_train")(::NamedTuple{(:cb, :maxiters),Tuple{var"#15#16",Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float32,2},Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}}}}) at ./none:0
[26] top-level scope at none:0
I also tried the minibatch approach suggested in #167, and it throws the same error.
I hope this different example (although maybe not an MWE 😅) is somewhat useful.
Closing as a duplicate. It all boils down to that DiffEqArray adjoint.
That issue was fixed with a fix to the adjoints.
Minibatching is being tracked a few other places as a tutorial update.
In a NeuralODE problem I am testing two trivial models.
I believe these are nominally identical models. Yet when called in batch mode, the Zygote gradient of my loss function fails for the second model but not the first. The failure occurs ONLY when broadcasting over a batch.
code to reproduce:
Analysis
The results of running this code are pasted below: model1 runs cleanly, while model2 errors in Zygote.gradient.
This is the MWE because there are two critical things required to trigger this error.
By batch orientation I mean: if you only want to run one initial condition at a time through concrete_solve, then the initial condition u0 can be a simple vector. But if you want to run multiple initial conditions at the same time, then the initial condition is a 2D matrix (a list of initial-condition vectors).
In this example, I am using the matrix format, as though it were a batch, but only for one initial condition. The error is the same if I tack additional initial conditions onto this matrix; I left it at one to keep things simple. The conversion step from a vector to a matrix is u0[:,:], and if you want to add more initial conditions in batch orientation:
minibatch = [([u0 u1 u2],)]
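To make the orientation concrete, here is a small sketch of the two shapes involved (the values are illustrative):

```julia
u0 = Float32[1.0, 1.0]   # single initial condition: a length-2 vector
u0_mat = u0[:, :]        # same data reshaped to a 2×1 matrix ("batch of one")
# Three initial conditions stacked column-wise into a 2×3 batch matrix:
batch = hcat(u0, u0 .+ 0.5f0, u0 .+ 1.0f0)

size(u0)      # (2,)
size(u0_mat)  # (2, 1)
size(batch)   # (2, 3)
```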
This error does not happen if the loss function loops over the initial conditions, solving them one at a time as simple vectors. Thus, this is a broadcasting problem.
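A sketch of that per-column alternative, with a toy `pred` standing in for the concrete_solve-based prediction above (both names here are stand-ins, not the issue's actual code):

```julia
# Toy stand-in for the ODE solve: scales each state by the parameters.
pred(u, p) = p .* u

# Loop over the columns of the minibatch, solving one vector at a time,
# instead of broadcasting the whole matrix through the solver.
loss_loop(p, mb) = sum(sum(abs, pred(mb[:, i], p) .- mb[:, i])
                       for i in 1:size(mb, 2))

p  = Float32[2.0, 2.0]
mb = Float32[1.0 2.0; 1.0 2.0]
loss_loop(p, mb)  # 6.0f0
```

Because each column goes through `pred` as a plain vector, this path avoids the batched-matrix broadcast that triggers the gradient error.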
Regression
In the code you can see that each component of the process is validated: the model, the ODE solver, and the loss functions all work fine with this batch-oriented input. The error happens in the gradient step, and only for the second model, not the first.
Output
Versions