SciML / ComponentArrays.jl

Arrays with arbitrarily nested named components.

Error differentiating functions with in-place mutation #78

Antomek closed 3 years ago

Antomek commented 3 years ago

Hello,

Thank you for the work on this very useful project!

I am running into an issue where I cannot use ComponentArrays with a function that mutates its du argument in place.

As an MWE, running this example, but replacing

function dudt(u, p, t)
    @unpack L1, L2 = p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

with

function dudt(du, u, p, t)
    @unpack L1, L2 = p
    du .= L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

yields the error:

ERROR: LoadError: type TrackedArray has no field L1

Am I doing something wrong? Can this be fixed somehow?

jonniedie commented 3 years ago

Oh, that’s related to #37. ReverseDiff wraps a TrackedArray type around the ComponentArray, and that wrapper hides the named components (hence the "no field L1" error). I actually wrote a way to handle this here, but I kinda forgot about it, so I don’t remember whether it actually worked or not 😅. For now, you can try that out and see if it works for you. I’ll try to take a look at it again soon so it can actually make it into the package.

If that doesn’t work, you can use ForwardDiff (which might be slow for something like a neural ODE) or switch to the out-of-place form and use Zygote as in the example. Generally for things like neural ODEs, the matrix operations are expensive enough that the extra allocation from the out-of-place form isn’t going to slow you down much, as far as I know.
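
For reference, here is a minimal sketch of the out-of-place route with Zygote. It only differentiates a single evaluation of dudt rather than a full ODE solve, and the layer sizes, the random weights, the input u0, and the toy loss function are all assumptions made up for illustration. It also uses plain property access instead of @unpack to keep the dependencies small.

```julia
using ComponentArrays, Zygote

# Out-of-place right-hand side, same math as the first snippet in the issue.
# Property access replaces @unpack here only to avoid an extra dependency.
function dudt(u, p, t)
    L1, L2 = p.L1, p.L2
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end

# Hypothetical layer sizes and random weights, chosen only for illustration.
p = ComponentArray(
    L1 = (W = randn(4, 2), b = zeros(4)),
    L2 = (W = randn(2, 4), b = zeros(2)),
)
u0 = [1.0, 0.5]

# Toy scalar loss on one right-hand-side evaluation (no ODE solve involved).
loss(p) = sum(abs2, dudt(u0, p, 0.0))

# Gradient of the loss with respect to all parameters at once.
grad = Zygote.gradient(loss, p)[1]
```

Because nothing is mutated, Zygote never has to trace writes into du, which is the part that triggers the TrackedArray wrapping in the ReverseDiff case. Whether this carries over unchanged to a full neural ODE training loop depends on the solver and sensitivity settings, so treat it as a starting point.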

bgroenks96 commented 3 years ago

I was just about to ask what the status of #37 is :D