ChrisRackauckas closed this 3 years ago.
@DhairyaLGandhi I think the `map` adjoint doesn't correctly ignore `nothing`s going backwards — could you take a look at this?
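A minimal probe of the suspected behavior (a hedged sketch, not a confirmed reproducer — the cotangent values here are made up for illustration): take the pullback of a `map` and feed it a cotangent containing `nothing` entries, as Zygote produces when part of an output is unused downstream. If the `map` adjoint handles `nothing` correctly it should treat those entries as zero; if not, this errors or propagates garbage.

```julia
using Zygote

# Forward pass through a plain map.
y, back = Zygote.pullback(x -> map(abs2, x), [1.0, 2.0, 3.0])

try
    # Cotangent with `nothing` entries, mimicking partially-unused output.
    grads = back([2.0, nothing, 2.0])
    @show grads
catch err
    @warn "map pullback failed on a cotangent containing `nothing`" err
end
```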
I have a similar issue to this thread in DiffEqFlux (https://github.com/SciML/DiffEqFlux.jl/issues/381), with something resembling the MWE in this thread. Was any progress made in the past months? Thank you.
Could we test with https://github.com/FluxML/Zygote.jl/pull/846 ?
That branch doesn't seem to help. In fact, I'm a bit puzzled and made another similar example to work with first:
```julia
using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA
CUDA.allowscalar(false)

function model()
    prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1] * p[2],
                      u0, (0.0, 1.0), pa)
    function prob_func(prob, i, repeat)
        prob
    end
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    solve(ensemble_prob, Tsit5(), EnsembleGPUArray(0.0), saveat = 0.1,
          trajectories = 10,
          sensealg = ForwardDiffSensitivity(convert_tspan = false))
end

# loss function
loss() = sum(abs2, 1.0 .- Array(model()))

data = Iterators.repeated((), 10)
cb = function () # callback function to observe training
    @show loss()
end

pa = [1.0, 2.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")
loss()
Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)
```
In the adjoint I specify, i.e. `ZygoteRules.@adjoint function batch_solve_up(ensembleprob, probs, alg, ensemblealg, I, u0, p; kwargs...)`, I have that:

`(size(Array(VectorOfArray(adj))), size(p)) = ((2, 10), (2, 10))`
So I know that what I'm pulling back is the same size as `p` (correct? I assume Zygote doesn't do anything crazy on matrices?). You would think that means it's working, but all the way back in the Flux `update!` code it sees

`(x, gs[x]) = ([1.0, 2.0], [46839.635021615635; 23419.817510807818])`

meaning the derivative somehow got adjointed (transposed) on its own... what?
Oh wait — remembering that Zygote's adjoints for comprehensions are incorrect, I got rid of the comprehensions. See that last commit; that's all I needed to fix this issue. So I think comprehensions incorrectly transpose variables being pulled back. @DhairyaLGandhi, you might want to take a look at that today and try to find a smaller reproducer, since this is an issue that keeps coming up.
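A starting point for a smaller reproducer might look like the following hedged sketch (the specific function and values are hypothetical, not taken from the commit): compute the gradient of the same reduction written once as a comprehension and once as the equivalent broadcast, and compare the shapes and values of the pulled-back gradients.

```julia
using Zygote

p = [1.0 2.0; 3.0 4.0]  # 2×2 parameter matrix

# Same loss, written as a comprehension vs. as a broadcast.
g_comp  = Zygote.gradient(p -> sum([p[i]^2 for i in eachindex(p)]), p)[1]
g_bcast = Zygote.gradient(p -> sum(p .^ 2), p)[1]

# If the comprehension adjoint transposes the pulled-back gradient,
# these will disagree in shape or orientation.
@show size(g_comp) size(g_bcast)
```

If such a test passes, the transposition may only trigger in more complex settings (e.g. through the custom `batch_solve_up` adjoint above), which would explain why a minimal reproducer is hard to find.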
The error isn't reproducible, so I'm just going to merge. But @vchuravy, it would be good to know why KernelAbstractions.jl sometimes fails to compile, and why the point at which it decides it can't is seemingly random — dependent on the computer, how many functions were run before it, and how many times the code has been run. I don't remember it being unstable like that.
Seems like the test issue was just a change in `@inbounds` semantics between different environments.
I'm having trouble reproducing the issue. I see you've gotten rid of the comprehensions, but is there a more minimal example I can use?
There isn't a more minimal example I could find.
Fixes https://github.com/SciML/DiffEqFlux.jl/issues/381. MWE: