SciML / DiffEqGPU.jl

GPU-acceleration routines for DifferentialEquations.jl and the broader SciML scientific machine learning ecosystem
MIT License

define adjoint #72

Closed: ChrisRackauckas closed this 3 years ago

ChrisRackauckas commented 4 years ago

Fixes . MWE:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA

function model()
  # scalar ODE du/dt = 1.01 * u * p[1]; u0 and pa are globals defined below
  prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1], u0, (0.0, 1.0), pa)

  # trajectory i starts from a shifted, scaled copy of u0
  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  solve(ensemble_prob, Tsit5(), EnsembleGPUArray(), saveat = 0.1, trajectories = 10, sensealg = ForwardDiffSensitivity(convert_tspan=false))
end

# loss function
loss() = sum(abs2,1.0.-Array(model()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end
pa = [1.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")


Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)
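For readers without a GPU, the following plain-Julia sketch mirrors what the MWE above computes. It is an approximation under stated assumptions: it uses the exact solution of the scalar ODE instead of the Tsit5 solve, and it assumes Array(model()) has a states × saved times × trajectories layout (1 × 11 × 10 here). No SciML or Flux calls are involved, and all helper names are hypothetical.

```julia
# Plain-Julia stand-in for the MWE (no GPU, no DiffEq packages).
# du/dt = 1.01 * u * p[1] has the exact solution u(t) = u0 * exp(1.01 * p[1] * t),
# used here in place of the Tsit5 solve. Helper names are hypothetical.

save_times = 0.0:0.1:1.0   # saveat = 0.1 over tspan (0.0, 1.0): 11 points

# Mirrors prob_func: trajectory i starts from 0.5 .+ i/100 .* u0
trajectory_u0(u0, i) = 0.5 .+ i / 100 .* u0

# One trajectory on the save grid as a 1 × 11 matrix (states × times)
exact_trajectory(u0_i, p) =
    reshape(map(t -> u0_i[1] * exp(1.01 * p[1] * t), save_times), 1, :)

# Stack trajectories into a states × times × trajectories array (assumed layout)
function ensemble_array(u0, p; trajectories = 10)
    sols = map(i -> exact_trajectory(trajectory_u0(u0, i), p), 1:trajectories)
    return cat(sols...; dims = 3)
end

# Same loss as the MWE: squared distance of every saved value from 1.0
sketch_loss(u0, p) = sum(abs2, 1.0 .- ensemble_array(u0, p))

u0 = [3.0]
pa = [1.0]
A = ensemble_array(u0, pa)
println(size(A))      # (1, 11, 10)
println(A[1, 1, 1])   # ≈ 0.53, the initial value of trajectory 1
```

Because the exact solution replaces the adaptive solver, the loss from this sketch will be close to, but not identical with, the value the GPU MWE prints.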
ChrisRackauckas commented 4 years ago

@DhairyaLGandhi I think the map adjoint doesn't correctly ignore nothings going backwards; could you take a look at this?
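In Zygote a gradient of nothing means "no gradient" (a structural zero). A pullback that combines per-element gradients, as a map adjoint does, therefore has to treat nothing as an additive zero rather than passing it along or trying to add to it. A minimal plain-Julia sketch of that rule, using a hypothetical accum_grad helper (Zygote's own accumulation code is not reproduced here):

```julia
# Hypothetical helper: accumulate two gradient contributions, where
# `nothing` means "no gradient" and must act as an additive zero.
accum_grad(::Nothing, ::Nothing) = nothing
accum_grad(a, ::Nothing) = a
accum_grad(::Nothing, b) = b
accum_grad(a, b) = a .+ b

# Sum per-element gradients coming back through a map, skipping `nothing`s
sum_map_grads(grads) = reduce(accum_grad, grads; init = nothing)

g = sum_map_grads([[1.0, 2.0], nothing, [0.5, 0.5]])
println(g)   # [1.5, 2.5]
```

The explicit (Nothing, Nothing) method resolves the method ambiguity between the two one-sided rules, and keeps an all-nothing input as nothing instead of inventing a zero array of unknown shape.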

jc-audet commented 3 years ago

I have an issue similar to the one in this DiffEqFlux thread:


with something resembling the MWE in this thread. Has any progress been made in the past few months?

Thank you

DhairyaLGandhi commented 3 years ago

Could we test with ?

ChrisRackauckas commented 3 years ago

That branch doesn't seem to help. In fact, I'm a bit puzzled, so I made another, similar example to work with first:

using OrdinaryDiffEq, DiffEqSensitivity, Flux, DiffEqGPU, StaticArrays, CUDA

function model()
  # scalar ODE du/dt = 1.01 * u * p[1] * p[2]; u0 and pa are globals defined below
  prob = ODEProblem((du, u, p, t) -> du[1] = 1.01 * u[1] * p[1] * p[2], u0, (0.0, 1.0), pa)

  # trajectory i starts from a shifted, scaled copy of u0
  function prob_func(prob, i, repeat)
    remake(prob, u0 = 0.5 .+ i/100 .* prob.u0)
  end

  ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
  solve(ensemble_prob, Tsit5(), EnsembleGPUArray(0.0), saveat = 0.1, trajectories = 10, sensealg = ForwardDiffSensitivity(convert_tspan=false))
end

# loss function
loss() = sum(abs2,1.0.-Array(model()))

data = Iterators.repeated((), 10)

cb = function () # callback function to observe training
  @show loss()
end

pa = [1.0,2.0]
u0 = [3.0]
opt = ADAM(0.1)
println("Starting to train")


Flux.@epochs 10 Flux.train!(loss, params([pa]), data, opt; cb = cb)

In the adjoint I define, i.e. ZygoteRules.@adjoint function batch_solve_up(ensembleprob,probs,alg,ensemblealg,I,u0,p;kwargs...), I see that:

(size(Array(VectorOfArray(adj))), size(p)) = ((2, 10), (2, 10))

So I know that what I'm pulling back is the same size as p (correct? I assume Zygote doesn't do anything crazy with matrices). You would think that works, but by the time it gets all the way back to Flux's update! code, it sees

(x, gs[x]) = ([1.0, 2.0], [46839.635021615635; 23419.817510807818])

i.e. the derivative somehow got adjointed on its own (the gradient comes back as a 2×1 matrix, while x is a length-2 vector)... what?
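The logged values themselves look consistent; only the shape is off. In the second MWE, p[1] and p[2] enter the ODE only through their product q = p[1] p[2], so for any loss L(q) the chain rule gives ∂L/∂p[1] = L'(q) p[2] and ∂L/∂p[2] = L'(q) p[1]. With pa = [1.0, 2.0] the entries should be in a 2:1 ratio, and 46839.635021615635 / 23419.817510807818 is 2 to within rounding. The sketch below checks that ratio with central differences on the exact solution u(t) = u0 exp(1.01 p[1] p[2] t), assuming the first MWE's initial conditions, so the magnitudes are not expected to match the log. It then shows one way a 2×1 matrix can arise: if, as the (2, 10) sizes suggest (an inference, not confirmed in the thread), p stacks the shared parameters once per trajectory, summing over trajectories with sum(...; dims = 2) and no vec gives exactly that shape. That second part is only an illustration with made-up numbers; the thread itself traces the problem to comprehension adjoints. All helper names are hypothetical.

```julia
# Part 1: the 2:1 gradient ratio for du/dt = 1.01 * u * p[1] * p[2].
# Exact solution summed into the MWE's loss; initial conditions follow the
# first MWE's prob_func (an assumption for this sketch).
function loss_exact(p; u0 = 3.0, trajectories = 10)
    total = 0.0
    for i in 1:trajectories, t in 0.0:0.1:1.0
        u0_i = 0.5 + i / 100 * u0
        total += abs2(1.0 - u0_i * exp(1.01 * p[1] * p[2] * t))
    end
    return total
end

# Central finite differences, one parameter at a time; returns a Vector like p
function fd_gradient(f, p; h = 1e-6)
    g = similar(p)
    for k in eachindex(p)
        e = zeros(length(p))
        e[k] = h
        g[k] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

pa = [1.0, 2.0]
g = fd_gradient(loss_exact, pa)
println(g[1] / g[2])            # ≈ 2.0, i.e. pa[2] / pa[1]

# Part 2: a per-trajectory adjoint (2 parameters × 10 trajectories, made-up numbers)
adj = reshape(collect(1.0:20.0), 2, 10)
g_matrix = sum(adj; dims = 2)   # 2×1 Matrix, the column shape seen in the log
g_vector = vec(g_matrix)        # length-2 Vector, same shape as pa
println(size(g_matrix), " ", size(g_vector))   # (2, 1) (2,)
```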

ChrisRackauckas commented 3 years ago

Oh wait, remembering that Zygote's adjoints for comprehensions are incorrect, I got rid of the comprehensions (see the last commit), and that's all it took to fix the issue. So I think comprehensions incorrectly transpose variables being pulled back. @DhairyaLGandhi, you might want to take a look at that today and try to find a smaller reproducer, since this issue keeps coming up.
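The last commit is not shown here, so the following is only a hedged illustration of the kind of rewrite described: building per-trajectory values with map and reduce(hcat, ...) instead of a comprehension. Both forms give the same forward result; the point of such a rewrite is that Zygote then differentiates through a different rule (its map adjoint rather than the one for collecting a generator). traj_params is a hypothetical stand-in for something like probs[i].p.

```julia
# Hypothetical per-trajectory parameter vectors (stand-ins for probs[i].p)
traj_params(i) = [1.0, 2.0] .* i

n = 10

# Comprehension form (the pattern the comment says was removed)
P_comp = hcat([traj_params(i) for i in 1:n]...)

# Comprehension-free form: map plus reduce(hcat, ...)
P_map = reduce(hcat, map(traj_params, 1:n))

println(size(P_map))       # (2, 10)
println(P_comp == P_map)   # true
```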

ChrisRackauckas commented 3 years ago

The error isn't reproducible, so I'm just going to merge. But @vchuravy, it would be good to know why KernelAbstractions.jl sometimes cannot compile: where it decides it can't seems random, depending on the computer, how many functions were run before it, and even how many times the code has been run. I don't remember it being unstable like that.

ChrisRackauckas commented 3 years ago

It seems the test issue was just that inbounds semantics change between different environments.
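@inbounds only permits the compiler to drop bounds checks; whether they are actually dropped depends on the session, and Julia's --check-bounds=yes option (which Pkg.test uses by default) forces them back on. Code whose result depends on an out-of-bounds access can therefore behave differently between environments. A minimal sketch, with a hypothetical safe_get helper, of making the check explicit so behavior does not depend on that flag:

```julia
# Hypothetical helper: read v[i] if i is a valid index, otherwise return a default.
# The explicit check makes the result the same whether or not bounds checking
# is enabled, so @inbounds is only applied where the index is known to be valid.
function safe_get(v::AbstractVector, i::Integer, default = zero(eltype(v)))
    if checkbounds(Bool, v, i)
        return @inbounds v[i]
    else
        return default
    end
end

v = [10.0, 20.0, 30.0]
println(safe_get(v, 2))   # 20.0
println(safe_get(v, 4))   # 0.0
```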

DhairyaLGandhi commented 3 years ago

I'm having trouble reproducing the issue. I see you've gotten rid of the comprehensions, but is there a more minimal example I can use?

ChrisRackauckas commented 3 years ago

There isn't a more minimal example I could find.