FastChain in loss #198

Closed: mkalia94 closed this issue 4 years ago.

mkalia94 commented 4 years ago


I was following the "Universal Differential Equations for Neural Optimal Control" example and wanted to add the output of the neural network ann as an extra term in the loss, as follows:

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq

u0 = Float32(1.1)
tspan = (0.0f0,25.0f0)

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
θ = Float32[u0;p3]

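function dudt_(du,u,p,t)
    x, y = u
    du[1] = ann(u,p[1:length(p1)])[1]
    du[2] = p[end-1]*y + p[end]*x
end
prob = ODEProblem(dudt_,u0,tspan,p3)

# Solve from the initial state [0, θ[1]] with parameters θ[2:end],
# saving at t = 0, 1, ..., 25 (solver call as in the tutorial)
function predict_adjoint(θ)
    Array(concrete_solve(prob,Vern9(),[0f0,θ[1]],θ[2:end],saveat=0.0:1:25.0))
end

# Tutorial loss plus the extra term: ann applied to the whole 2×26 solution matrix
loss_adjoint(θ) = sum(abs2,predict_adjoint(θ)[2,:].-1) + sum(abs2,ann(predict_adjoint(θ),θ[2:end]))
l = loss_adjoint(θ)

cb = function (θ,l)
  println(l)
  return false
end

# Print the loss at the initial parameter values.
cb(θ,l)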
res = DiffEqFlux.sciml_train(loss_adjoint, θ, BFGS(initial_stepnorm=0.01), cb = cb)

However, this results in the following error:

ERROR: LoadError: ArgumentError: number of columns of each array must match (got (1, 26))

which I do not understand. Any help would be really appreciated! Thanks for the great package :)
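One plausible reading of this message (an assumption, not verified against the DiffEqFlux source): ann is called on the whole 2×26 solution matrix (2 states, 26 saved time points), and in the backward pass the bias gradient comes out as a matrix with one column per time point instead of a vector, so concatenating it with the flattened weight gradient fails. The plain-Julia sketch below reproduces that kind of shape mismatch with made-up zero gradients (grad_W, grad_b are hypothetical names); no DiffEqFlux code is involved, and the exact exception type may differ between Julia versions.

```julia
# Hypothetical gradient pieces for the 2 -> 16 first layer of ann,
# evaluated on 26 samples (one per saved time point).
n_in, n_out, n_samples = 2, 16, 26

grad_W = zeros(Float32, n_out * n_in)      # flattened weight gradient: a Vector
grad_b = zeros(Float32, n_out, n_samples)  # bias gradient left unsummed: a 16x26 Matrix

# Packing these into one flat gradient fails: the Vector counts as
# 1 column while the Matrix has 26, the same "(1, 26)" mismatch as above.
err = try
    vcat(grad_W, grad_b)
    nothing
catch e
    e
end
println(err)

# Summing the bias gradient over the sample dimension restores a Vector.
packed = vcat(grad_W, vec(sum(grad_b, dims=2)))
println(size(packed))  # (48,)
```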

mkalia94 commented 4 years ago

I have a scrappy workaround, which is to replace the loss with:

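function loss_adjoint(θ)
    p = predict_adjoint(θ)   # 2×26 solution matrix, one column per saved time point
    l = 0f0
    # Apply ann to one column (a Vector) at a time, which FastChain handles
    for i in 1:size(p, 2)
        l = l + sum(abs2, ann(p[:, i], θ[2:end]))
    end
    l = l + sum(abs2, p[2, :] .- 1)
    return l
end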

Is there a better way to do this?
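The loop gives the same value the matrix call would, because sum(abs2, ·) over a matrix is just the sum of the per-column results. A plain-Julia check, using a hypothetical one-layer stand-in toy_net for ann and a made-up 2x26 pred matrix (columnwise_loss and matrix_loss are illustrative names):

```julia
# Hypothetical stand-in for ann: a single dense layer 2 -> 3 with tanh.
W = Float32[0.1 -0.2; 0.3 0.4; -0.5 0.6]
b = Float32[0.01, -0.02, 0.03]
toy_net(x) = tanh.(W * x .+ b)  # accepts a Vector or a Matrix (one sample per column)

# Hypothetical solution matrix: 2 states x 26 saved time points.
pred = Float32.(reshape(range(0, 1; length=52), 2, 26))

# Loss accumulated column by column, as in the workaround above.
function columnwise_loss(f, pred)
    l = 0f0
    for i in 1:size(pred, 2)
        l += sum(abs2, f(pred[:, i]))
    end
    return l
end

# The same loss from a single call on the whole matrix.
matrix_loss(f, pred) = sum(abs2, f(pred))

println(columnwise_loss(toy_net, pred) ≈ matrix_loss(toy_net, pred))  # true
```

So once the network's backward pass accepts matrices, the single matrix call can replace the loop without changing the loss.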

ChrisRackauckas commented 4 years ago

This is a duplicate of . Essentially, FastChain can act on a Vector, but it doesn't work so well on a Matrix, even though neural networks generally "can" act on a matrix (one sample per column). The custom definition of the backward pass hasn't been matrix-proofed yet, but if someone has about 30 minutes to dig through it, it shouldn't be too difficult: I believe it just needs a sum(?, dims=2) on one of the pullback terms.
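To illustrate where such a sum would enter, here is a hand-written pullback for a generic dense layer y = tanh.(W*x .+ b). This is a sketch of the math only, not DiffEqFlux's actual adjoint; dense_pullback and columnwise_pullback are hypothetical helpers. With x a matrix, the weight gradient Δz * x' already sums over samples through the matrix product, but the bias b is broadcast to every column, so its gradient must be summed over columns explicitly. The check compares the matrix version against running the single-sample pullback column by column.

```julia
# Hypothetical hand-written pullback of a dense layer y = tanh.(W*x .+ b).
# x is either a Vector (one sample) or a Matrix (one sample per column).
function dense_pullback(W, b, x, Δy)
    z = W * x .+ b
    Δz = Δy .* (1 .- tanh.(z) .^ 2)  # chain rule through tanh
    ΔW = Δz * x'                      # the matrix product already sums over samples
    Δb = vec(sum(Δz, dims=2))         # b is shared by every column, so sum over columns
    Δx = W' * Δz
    return ΔW, Δb, Δx
end

# Reference: run the single-sample pullback on each column and add up the results.
function columnwise_pullback(W, b, x, Δy)
    ΔW, Δb = zero(W), zero(b)
    for i in 1:size(x, 2)
        gW, gb, _ = dense_pullback(W, b, x[:, i], Δy[:, i])
        ΔW .+= gW
        Δb .+= gb
    end
    return ΔW, Δb
end

W = Float32[0.1 -0.2; 0.3 0.4; -0.5 0.6]
b = Float32[0.01, -0.02, 0.03]
x = Float32.(reshape(range(0, 1; length=8), 2, 4))  # 2 inputs x 4 samples
Δy = ones(Float32, 3, 4)                              # upstream gradient

ΔW, Δb, Δx = dense_pullback(W, b, x, Δy)
ΔW_ref, Δb_ref = columnwise_pullback(W, b, x, Δy)

println(ΔW ≈ ΔW_ref, " ", Δb ≈ Δb_ref)  # true true
```

Dropping the sum and returning Δz as the bias gradient would leave a 3x4 matrix, which no longer fits alongside the vector-shaped parameter gradient.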

mkalia94 commented 4 years ago

Thanks. Also, the quick fix above doesn't work if ann has more than one output, for example:

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,2))

Looping in the new loss with the initialization l = [0;0] then results in a "mutating arrays is not supported" error. Is there a way around it?

mkalia94 commented 4 years ago

The above issue can be fixed with

sum(abs2, hcat([ann(p[:, i], θ[2:end]) for i in 1:size(p, 2)]...))

I don't have much experience with Julia, so I won't be able to fix FastChain's handling of Matrix inputs, but this works as a quick fix for now.
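Since the loss only needs one scalar per column, the per-column results can also be combined without stacking them into a matrix or mutating an accumulator. The plain-Julia sketch below compares the hcat form above with sum(f, itr) over the column indices, using a hypothetical 2-output stand-in toy_net2 for ann and a made-up pred matrix. It only checks that the two values agree; whether Zygote differentiates the sum(f, itr) form as well as the hcat form for this model is an assumption worth testing.

```julia
# Hypothetical 2 -> 2 stand-in for ann.
W2 = Float32[0.2 -0.1; 0.05 0.3]
b2 = Float32[0.0, 0.1]
toy_net2(x) = tanh.(W2 * x .+ b2)

# Hypothetical solution matrix: 2 states x 26 saved time points.
pred = Float32.(reshape(range(0, 1; length=52), 2, 26))

# Workaround from the thread: stack per-column outputs, then reduce.
loss_hcat = sum(abs2, hcat([toy_net2(pred[:, i]) for i in 1:size(pred, 2)]...))

# Alternative: one scalar per column, summed with no intermediate matrix
# and no in-place updates.
loss_sum = sum(i -> sum(abs2, toy_net2(pred[:, i])), 1:size(pred, 2))

println(loss_hcat ≈ loss_sum)  # true
```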