Update: The crux: I found the single-line change that converts this from working to non-working. This is not the fix; it just pinpoints what makes Zygote fail.
The batch of k items to train on is specified as a 2D matrix of shape 2×k (each column is one initial-condition vector):
# batch of 4 items: [u0 u1 u2 u3] is a 2D array of shape 2×4
minibatch = [([u0 u1 u2 u3],)]
When a batch contains just a single item, I have the option of writing it either as a 2D matrix (2×1), as above, or as a 1D vector.
The 1D vector works in sciml_train, but the 2D matrix fails!
# batch of 1 item as a 2D array of shape 2×1
minibatch = [(reshape(u0,2,1),)] # this does NOT work in sciml_train
# batch of 1 item as a 1D array of size 2
minibatch = [(u0,)] # this works in sciml_train!
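To make the single-line difference explicit, here is a quick shape check (a sketch, with u0 as defined in the reproduction code below):

# the failing batch is a 2×1 Matrix, the working one a length-2 Vector
size(reshape(u0,2,1))  # (2, 1) -- 2D, fails in sciml_train
size(u0)               # (2,)   -- 1D, works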
Note that the dimensions and shapes are all correct for the loss function; the issue is that Zygote fails when taking the gradient in the 2D case.
A related observation confirms this: converting the loss to serial processing of the batch, instead of the vectorized form, makes it work, because each item is then handled as a 1D vector again.
# vectorized version of the loss: fails in sciml_train
loss_vector(p,mb) = sum(abs, pred(mb,p) .- mb)
where mb is a 2D array whose columns are the initial-condition vectors.
Changing the vectorized version to an explicit loop does work in sciml_train:
function loss_serial(p,mb)
    temp = 0.0f0
    for i in 1:size(mb,2)   # explicitly loop over each initial condition (column)
        j = mb[:,i]
        temp += sum(abs, pred(j,p) .- j)
    end
    temp
end
Key point: BOTH of these functions work just fine when called manually, and also during the initial calls sciml_train makes to compute the loss. The vectorized version fails in sciml_train only when it comes time for Zygote to take the gradient.
1. The serial version works with sciml_train.
2. The vectorized version works in sciml_train only if the batch is a single item stored as a 1D vector.
3. The vectorized version fails in sciml_train when the batch holds multiple training items in a 2D matrix.
So it's something about the 2D shape of the batch data that causes the Zygote gradient problem.
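If that diagnosis is right, the failure should be reproducible without sciml_train at all, just by calling Zygote.gradient on the vectorized loss directly. A minimal sketch (pred, p, and u0..u3 as in the reproduction code below; mb2d is just a name I'm introducing; Zygote is already a dependency of Flux here):

using Zygote
Zygote.gradient(p -> sum(abs, pred(u0,p) .- u0), p)       # 1D Vector case: expected to work
mb2d = [u0 u1 u2 u3]                                      # 2×4 Matrix (the failing shape)
Zygote.gradient(p -> sum(abs, pred(mb2d,p) .- mb2d), p)   # expected to reproduce the Zygote error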
=============
Full description of the problem
=============
Problem area: batching, i.e. the loss function considers multiple training examples at once. I have two ways to do this:
1. Serial: loop over calls to the ODE solver, one call per initial condition in a supplied list.
2. Vectorized: call the ODE solver on the entire matrix, so all the initial conditions are processed simultaneously. (The forward call itself works fine.)
However, sciml_train fails with the second ("vectorized") loss method.
Why I think the error is related to Zygote / sciml_train():
When I add print statements before the returns from loss() and from the model f(x,p,t), I see that, yes indeed, the loss function returns normally.
However, after the loss function returns, one observes a further call to the model f(x,p,t). This call also finishes cleanly (no NaNs, correct dimensions).
Then the error happens after that return from f(x,p,t).
Since the only place in my code that calls f(x,p,t) is the loss function, it logically cannot be my code making that final call after the loss function has returned. So I think this has to be Zygote computing the derivatives.
Regression tests:
1. When called manually, the model, the prediction, and the loss function all work fine for both single-item and multi-item batches.
2. These functions work inside either the serial or the vectorized loss function.
3. The serial and vectorized versions agree numerically (see the check sketched right after this list).
4. The numerical output when called by sciml_train agrees with the manual calls.
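Sketch of the agreement check mentioned in point 3, using loss_serial from above and loss_batch plus the data from the reproduction code below (mbcheck is just a name; the tolerance allows for small differences from adaptive stepping):

mbcheck = [u0 u1 u2 u3]
@show loss_serial(p, mbcheck)
@show loss_batch(p, mbcheck)
@assert isapprox(loss_serial(p, mbcheck), loss_batch(p, mbcheck); rtol=1e-3)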
Also, none of the input arrays change dimensions, yet the Zygote error message complains of a dimension mismatch.
The error is not raised from statements in my code; it is raised from within Zygote.
Strawman rejected: The only thing I find odd in my program is the setup of prob (the ODE problem). Notably, ODEProblem requires you to supply an initial condition, but that initial condition is later overridden by concrete_solve, so it is just a placeholder. (Why do we need it? I'm guessing for solve()? It would be nicer to have a prototype signature without this unneeded term.)
However, I tested this by moving the ODE problem setup inside the loss function, and I get the same error. So that's not it.
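Concretely, the variant I tested looks roughly like this (a sketch; loss_batch_local and prob_local are just names for illustration, and the commented-out lines inside loss_batch in the reproduction code below correspond to it):

function loss_batch_local(p,mb)
    prob_local = ODEProblem(f,mb,(0.0f0,1.0f0),p)   # rebuild the ODE problem from the batch itself
    sum(abs, concrete_solve(prob_local,Tsit5(),mb,p, saveat=0.01) .- mb)
end
# still fails in sciml_train for the 2D batch, so the placeholder initial condition is not the cause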
Code to reproduce:
#=
Status `~/.julia/environments/v1.3/Project.toml`
[c52e3926] Atom v0.12.3
[aae7a2af] DiffEqFlux v1.3.2 #master (https://github.com/JuliaDiffEq/DiffEqFlux.jl.git)
[0c46a032] DifferentialEquations v6.11.0
[587475ba] Flux v0.10.1
[7073ff75] IJulia v1.21.1
[e5e0dc1b] Juno v0.7.2
[429524aa] Optim v0.20.1
[1dea7af3] OrdinaryDiffEq v5.29.0
[91a5bcdd] Plots v0.29.1
[d330b81b] PyPlot v2.8.2
=#
using Flux           # for ADAM
using Optim          # for BFGS
using OrdinaryDiffEq
using DiffEqFlux
model1 = FastChain(FastDense(2,2))
p = initial_params(model1)
function f(x,p,t)
    println("fx ",x)
    println("fp ",p)
    model1(x,p)   # the Juno error traceback highlights this line in red
end
u0 = Float32[1.0,1.0]
u1 = Float32[0.0,1.0]
u2 = Float32[1.0,0.0]
u3 = Float32[1.0,0.7]
#minibatch = [([u0 u1 u2 u3],)] # multiple training cases, 2D matrix: gives the error
minibatch = [(u0[:,:],)]        # single training case, 2D matrix: gives the error
#minibatch = [(u0,)]            # single training case, 1D vector: does NOT give the error
prob = ODEProblem(f,minibatch[1][1],(0.0f0,1.0f0),p) # the initial-condition slot is just a placeholder; concrete_solve overrides it
pred(u,p) = concrete_solve(prob,Tsit5(),u,p, saveat=0.01)
#validate
pred(u0,p)
pred(minibatch[1][1],p)
function loss_batch(p,mb)
    #prob = ODEProblem(f,mb,(0.0f0,1.0f0),p)                   # moving these two lines inside the loss function
    #pred(u,p) = concrete_solve(prob,Tsit5(),u,p, saveat=0.01) # does not fix the problem
    println("==============================")
    println("mb",mb)
    println("p",p)
    sum(abs, pred(mb,p) .- mb)
end
#validate
loss_batch(p,minibatch[1]...)
function cb1(args...)
    println("args:",args[1:2])
    false
end
res0 = DiffEqFlux.sciml_train(loss_batch,p,ADAM(0.005),minibatch,maxiters=300,cb=cb1)
But running this manually on the exact same numerical data works fine! Here I show f() receiving the same data it received just before the error, and the loss function being called with the data that triggered the error. Both run cleanly when called manually.
(The program was modified to be less verbose and to report when each function returns.)
julia> loss_batch(Float32[0.47264105, 0.15957993, -0.8247262, -0.9430232, 0.0, 0.0],Float32[1.0 0.0 1.0 1.0; 1.0 1.0 0.0 0.7])
==============================
p,mbFloat32[0.47264105, 0.15957993, -0.8247262, -0.9430232, 0.0, 0.0], Float32[1.0 0.0 1.0 1.0; 1.0 1.0 0.0 0.7]
function f returning nowFloat32[-0.35208517 -0.8247262 0.47264105 -0.1046673; -0.7834433 -0.9430232 0.15957993 -0.5005363]
function f returning nowFloat32[-0.35208517 -0.8247262 0.47264105 -0.1046673; -0.7834433 -0.9430232 0.15957993 -0.5005363]
function f returning nowFloat32[-0.3520712 -0.8247149 0.47264367 -0.10465668; -0.78342336 -0.94300115 0.15957774 -0.50052303]
function f returning nowFloat32[-0.3518602 -0.82454425 0.4726841 -0.10449693; -0.78312314 -0.94266784 0.15954472 -0.50032276]
function f returning nowFloat32[-0.35162842 -0.8243569 0.47272855 -0.10432134; -0.78279334 -0.9423018 0.15950848 -0.5001028]
function f returning nowFloat32[-0.35082868 -0.82371074 0.47288203 -0.10371547; -0.78165567 -0.94103914 0.15938345 -0.4993439]
function f returning nowFloat32[-0.35071707 -0.8236206 0.47290346 -0.103630885; -0.7814969 -0.9408629 0.159366 -0.49923798]
function f returning nowFloat32[-0.35068923 -0.8235981 0.47290882 -0.10360981; -0.7814573 -0.9408189 0.15936165 -0.49921158]
function f returning nowFloat32[-0.35068923 -0.8235981 0.47290882 -0.10360981; -0.7814573 -0.94081897 0.15936165 -0.49921158]
function f returning nowFloat32[-0.3501717 -0.8231801 0.4730084 -0.10321762; -0.7807211 -0.94000185 0.15928076 -0.4987205]
function f returning nowFloat32[-0.34963885 -0.82275015 0.4731113 -0.102813795; -0.7799634 -0.93916094 0.15919757 -0.49821508]
function f returning nowFloat32[-0.34780204 -0.8212693 0.47346726 -0.10142117; -0.77735204 -0.9362631 0.15891102 -0.49647304]
function f returning nowFloat32[-0.34754583 -0.821063 0.47351712 -0.10122694; -0.77698797 -0.9358591 0.15887111 -0.4962302]
function f returning nowFloat32[-0.34748188 -0.8210115 0.47352958 -0.10117844; -0.7768971 -0.93575823 0.15886116 -0.4961696]
function f returning nowFloat32[-0.34748197 -0.8210116 0.4735296 -0.10117851; -0.77689725 -0.9357585 0.15886118 -0.4961697]
function f returning nowFloat32[-0.34629723 -0.82005763 0.47376028 -0.10028001; -0.7752135 -0.93389016 0.15867656 -0.4950465]
function f returning nowFloat32[-0.34508005 -0.8190796 0.47399956 -0.09935615; -0.77348477 -0.93197215 0.15848735 -0.49389312]
function f returning nowFloat32[-0.340891 -0.8157204 0.4748294 -0.09617487; -0.7675388 -0.9253761 0.15783736 -0.4899259]
function f returning nowFloat32[-0.3403078 -0.81525373 0.47494593 -0.09573173; -0.7667115 -0.9244586 0.15774706 -0.48937395]
function f returning nowFloat32[-0.3401623 -0.8151374 0.47497502 -0.09562114; -0.7665051 -0.9242298 0.15772454 -0.48923624]
function f returning nowFloat32[-0.34016293 -0.8151381 0.47497514 -0.09562151; -0.76650614 -0.9242309 0.15772468 -0.4892369]
function f returning nowFloat32[-0.33721623 -0.81278163 0.4755653 -0.09338177; -0.7623269 -0.9195957 0.15726873 -0.4864482]
function f returning nowFloat32[-0.3342048 -0.81038725 0.47618237 -0.09108869; -0.75806314 -0.91486865 0.15680543 -0.4836026]
function f returning nowFloat32[-0.323886 -0.80222374 0.47833765 -0.08321894; -0.74347454 -0.8987005 0.15522581 -0.47386447]
function f returning nowFloat32[-0.32245553 -0.80109805 0.4786425 -0.0821261; -0.7414553 -0.8964634 0.15500802 -0.4725163]
function f returning nowFloat32[-0.32209903 -0.8008178 0.47871876 -0.08185372; -0.74095225 -0.8959061 0.1549538 -0.47218043]
function f returning nowFloat32[-0.3221053 -0.80082506 0.4787197 -0.08185779; -0.74096227 -0.8959176 0.15495518 -0.47218704]
function f returning nowFloat32[-0.31558764 -0.7957061 0.48011833 -0.07687585; -0.7317673 -0.88573205 0.15396468 -0.46604767]
function f returning nowFloat32[-0.30900088 -0.7906052 0.48160422 -0.07181938; -0.72251254 -0.8754903 0.15297769 -0.45986548]
function f returning nowFloat32[-0.28661725 -0.7734787 0.48686138 -0.054573644; -0.6911712 -0.8408351 0.14966382 -0.4389207]
function f returning nowFloat32[-0.28353238 -0.7711471 0.48761454 -0.052188344; -0.6868668 -0.8360796 0.14921263 -0.436043]
function f returning nowFloat32[-0.28276563 -0.7705691 0.4878035 -0.051594906; -0.6857978 -0.83489865 0.14910083 -0.43532822]
function f returning nowFloat32[-0.2828238 -0.7706339 0.48781 -0.05163373; -0.68588984 -0.8350034 0.14911336 -0.43538892]
function f returning nowFloat32[-0.270896 -0.76167643 0.4907804 -0.04239313; -0.669277 -0.81665725 0.14738014 -0.4242799]
function f returning nowFloat32[-0.25906157 -0.75306773 0.4940061 -0.03314135; -0.65294003 -0.7986545 0.14571442 -0.4133437]
function f returning nowFloat32[-0.21927987 -0.72489256 0.50561273 -0.001812116; -0.5984218 -0.7386845 0.14026265 -0.37681645]
function f returning nowFloat32[-0.21379752 -0.72109854 0.507301 0.0025320007; -0.5909551 -0.73048365 0.13952854 -0.37181]
function f returning nowFloat32[-0.21244155 -0.7201674 0.5077259 0.003608575; -0.5891121 -0.72846043 0.13934837 -0.37057403]
function f returning nowFloat32[-0.2128105 -0.72056866 0.5077581 0.0033600465; -0.5896908 -0.729117 0.13942602 -0.37095577]
function f returning nowFloat32[-0.19394788 -0.70781827 0.51387036 0.018397572; -0.5641594 -0.70111835 0.1369589 -0.35382387]
function f returning nowFloat32[-0.17573652 -0.6964107 0.5206741 0.03318669; -0.5399815 -0.67473316 0.1347516 -0.33756152]
function f returning nowFloat32[-0.11501745 -0.6606294 0.54561204 0.083171405; -0.4605477 -0.5883757 0.12782812 -0.2840349]
function f returning nowFloat32[-0.10641288 -0.6557103 0.54929745 0.09030023; -0.44937027 -0.5762465 0.12687628 -0.27649626]
function f returning nowFloat32[-0.104304306 -0.6545324 0.550228 0.092055336; -0.44664562 -0.573294 0.12664832 -0.27465746]
function f returning nowFloat32[-0.10605597 -0.656397 0.5503409 0.0908631; -0.44937247 -0.57638174 0.12700915 -0.27645794]
function f returning nowFloat32[-0.08682744 -0.6464902 0.55966264 0.10711958; -0.42496237 -0.5500548 0.12509225 -0.25994596]
function f returning nowFloat32[-0.06838242 -0.63840485 0.57002234 0.12313901; -0.4022885 -0.5258164 0.12352779 -0.24454355]
function f returning nowFloat32[-0.006575223 -0.6146271 0.6080518 0.17781283; -0.32804474 -0.4469718 0.11892693 -0.19395325]
function f returning nowFloat32[0.0024408372 -0.6112648 0.61370534 0.18582001; -0.31727004 -0.43554673 0.11827634 -0.18660627]
function f returning nowFloat32[0.0046428894 -0.6104887 0.61513144 0.18778938; -0.31466216 -0.43278855 0.11812618 -0.18482572]
function f returning nowFloat32[0.002099994 -0.61310846 0.6152084 0.18603249; -0.31857523 -0.4372085 0.11863309 -0.1874128]
loss_batch returning now168.21149
+++++++++++++++++
168.21149f0
julia>
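For comparison, the serial loss does train for me. A sketch of that call, using the same pred, p, minibatch, and cb1 as above (loss_serial is repeated here from the top of the post; res_serial is just a name I'm using):

function loss_serial(p,mb)
    temp = 0.0f0
    for i in 1:size(mb,2)   # one ODE solve per column (initial condition)
        j = mb[:,i]
        temp += sum(abs, pred(j,p) .- j)
    end
    temp
end
res_serial = DiffEqFlux.sciml_train(loss_serial,p,ADAM(0.005),minibatch,maxiters=300,cb=cb1)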
Appendix / Update: I changed the title of this post because the error does not require using the "optional data" argument. If you remove the optional data and instead slide the minibatch into the loss function via a hard-coded closure, the vectorized version still does not work.
Code changes to switch from using the optional data argument of sciml_train to having the loss function carry the data itself:
# this change removes the optional-data mechanism
loss_vec(p,mb) = sum(abs, pred(mb,p) .- mb)
loss(p) = loss_vec(p,minibatch[1]...)   # closure to embed the data
# no optional data passed to sciml_train
res0 = DiffEqFlux.sciml_train(loss,p,ADAM(0.005),maxiters=300,cb=cb1)
Again, the vectorized method does not work; the serial method does.
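For completeness, a sketch of the serial counterpart of the closure version, which is the one that works for me (loss_serial as defined above; loss_serial_closure and res1 are names for illustration):

loss_serial_closure(p) = loss_serial(p,minibatch[1]...)   # same closure pattern, serial loss
res1 = DiffEqFlux.sciml_train(loss_serial_closure,p,ADAM(0.005),maxiters=300,cb=cb1)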