FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

∇getindex mutates, causing issues with higher order AD over getindex. #820

Closed jlmaccal closed 1 year ago

jlmaccal commented 4 years ago

I'm new to Flux/Zygote/Julia. I'm trying to develop a model that looks something like the one below.

I have a network that produces two outputs, A and B. The gradient of A with respect to the inputs is part of my loss function, along with other terms that depend on B. I've just summed things here for simplicity, but my actual model produces the same error.

using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> (A=x[1, :], B=x[2:end, :])
)
θ, builder = Flux.destructure(net)

x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    results, pullback = Zygote.pullback(net, x)
    A = results.A
    B = results.B
    ∇A = pullback((A=ones(eltype(A), size(A)), B=nothing))[1]
    a = sum(∇A; dims=1)
    b = sum(B; dims=1)
    return a + b
end

Zygote.gradient(θ -> sum(abs2, predict(θ, x, builder)), θ)

The error is ERROR: LoadError: Mutating arrays is not supported, which comes from the x -> (A=x[1, :], B=x[2:end, :]) line in the network, but I don't understand where the mutation is coming from.

I gather from a number of issues here and threads on Discourse that higher-order gradients are not well supported, but there isn't much documentation around this. As a new user, it would be extremely helpful if there was some kind of documentation / guidance about how to work around this.

On a related Discourse thread @ChrisRackauckas suggested using another AD, like ReverseDiff, but I can't figure out how to get the gradient that I want. Any guidance would be appreciated.

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#368#369")(::Nothing) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/array.jl:61
 [3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize! at ./broadcast.jl:845 [inlined]
 [6] materialize! at ./broadcast.jl:841 [inlined]
 [7] (::typeof(∂(materialize!)))(::Nothing) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [8] #356 at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/array.jl:42 [inlined]
 [9] (::typeof(∂(λ)))(::Tuple{Array{Float32,2},Nothing,Nothing}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [10] #2209#back at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [11] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2},Nothing,Nothing}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [12] #11 at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:9 [inlined]
 [13] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [14] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [15] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [16] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [17] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [18] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [19] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [20] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [21] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [22] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [23] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [24] Chain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:38 [inlined]
 [25] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [26] #41 at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45 [inlined]
 [27] (::typeof(∂(λ)))(::Tuple{Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [28] predict at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:20 [inlined]
 [29] (::typeof(∂(predict)))(::Array{Float32,2}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [30] #13 at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26 [inlined]
 [31] (::typeof(∂(#13)))(::Float32) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#41#42"{typeof(∂(#13))})(::Float32) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
 [33] gradient(::Function, ::Array{Float32,1}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [34] top-level scope at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26
 [35] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [36] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at ./essentials.jl:710
 [37] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [38] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:83
 [39] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:45
 [40] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:118
 [41] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:43
 [42] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:36
 [43] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:23
 [44] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/JSONRPC/src/typed.jl:66
 [45] macro expansion at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/VSCodeServer.jl:95 [inlined]
 [46] (::VSCodeServer.var"#61#63"{Bool,String})() at ./task.jl:356
in expression starting at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26
ChrisRackauckas commented 3 years ago

Your example can be reduced even more:

using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> x[1, :]
)
θ, builder = Flux.destructure(net)

x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    A, pullback = Zygote.pullback(net, x)
    ∇A = pullback(ones(eltype(A), size(A)))[1]
    a = sum(∇A)
end

Zygote.gradient(θ -> predict(θ, x, builder), θ)

The issue is just that you can't nest Zygote. If you make the outer differentiation use ReverseDiff.jl and the inner one use Zygote.jl you're fine though.
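
For reference, a minimal (untested) sketch of that combination on the reduced example above, with ReverseDiff taking the outer gradient and Zygote the inner pullback:

using ReverseDiff  # outer AD; Zygote is still used inside `predict`

# `predict`, `builder`, `x`, and `θ` are as defined in the reduced example above.
outer_grad = ReverseDiff.gradient(θ -> predict(θ, x, builder), θ)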

jlmaccal commented 3 years ago

@ChrisRackauckas Thanks for the reply. I can see now that many of the pullbacks defined for common operations, like getindex, mutate in place, which prevents Zygote from being nested.

Do you know if there is any plan to address this? It seems to me that loss functions that use the gradients of a scalar field could be quite common when modelling physical systems with conserved quantities, as with Hamiltonian NNs.

I have tried combining various ADs as you suggest, but keep running into problems. I give some examples of things I've tried below. My apologies for their length, but I feel it might help to show what I'm actually trying to accomplish.

For context, I am trying to model the power dissipated by a non-equilibrium thermodynamic system undergoing some control protocol. This is modelled with two components. The first is a conservative term that depends on the directional derivative of a scalar field (the free energy). The second is a dissipative term that depends on a positive definite friction tensor. In the spirit of SciML, I'm trying to model both the free energy and the friction tensor using NNs.
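
In symbols (a hedged reading of the model just described, matching the comments in the code below, with λ the control parameters and dλ their rate of change):

P(\lambda, d\lambda) \;=\; \underbrace{d\lambda \cdot \nabla_{\lambda} F(\lambda)}_{P_{\mathrm{cons}}} \;+\; \underbrace{d\lambda^{\top}\, \xi(\lambda)\, d\lambda}_{P_{\mathrm{diss}}}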

jlmaccal commented 3 years ago

Here is one approach I tried, using ForwardDiff to calculate the inner directional derivative and Zygote for the outer gradient.

If I don't include DiffEqFlux, I get an error about no method matching *(::NamedTuple..., which is addressed by one of the adjoint definitions in DiffEqFlux.

However, this does not give me the correct gradient. Instead, all of the gradients for the ξnet parameters are zero. If I change return Pcons + Pdiss to return Pdiss, then I get the correct (or at least non-zero) gradient for the ξnet parameters, but the Fnet gradients are zero.

using Flux
using Zygote
using DiffEqFlux
using ForwardDiff
using NNlib
using LinearAlgebra
using Statistics

export predictpower, create_Fnetwork, create_ξnetwork, combine_networks

struct Builder{R1,R2}
    re1::R1
    re2::R2
    n::Int64
end

(builder::Builder)(p) = begin
    p1 = p[1:builder.n]
    p2 = p[(builder.n + 1):end]
    return (builder.re1(p1), builder.re2(p2))
end

function combine_models(m1, m2)
    p1, re1 = Flux.destructure(m1)
    p2, re2 = Flux.destructure(m2)
    n = size(p1)[1]
    p = [p1; p2]
    builder = Builder(re1, re2, n)
    return (p, builder)
end

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end
const DD = DirectionalDerivative

function (dd::DD)(pt)
    let dd=dd
        ForwardDiff.derivative(0) do h
            dd.f(pt + h * dd.direction)
        end
    end
end

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2

    # Unpack the inputs
    Fnet, ξnet = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]

    # Compute the directional derivative dλ⋅∇F
    Pcons = DD(Fnet, dλ)(λ)

    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])

    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    ξ = ξnet(λ)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)

    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_Fnetwork(controldim, hiddendim, hiddenlayers)
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, 1)
    return Chain(initial, layers..., final)
end

function create_ξnetwork(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, componentdim)
    posdef = VecToPosDef(componentdim, controldim)
    return Chain(initial, layers..., final, posdef)
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take an `indim`-dimensional batch of vectors and convert it to
a batch of `(n, n)` positive definite matrices. The dimensions
must match such that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as the elements of a lower triangular
matrix. The diagonal elements are exponentiated to ensure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64

    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim

    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)

    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end

    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end

    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end

# Test code

Fnet = create_Fnetwork(2, 128, 2)
ξnet = create_ξnetwork(2, 128, 2)
θ, builder = combine_models(Fnet, ξnet)

x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = Zygote.gradient(p -> loss(x, p, builder), θ)[1]
grad = getindex.(ForwardDiff.partials.(grad),1)
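
As a quick sanity check of the VecToPosDef layer outside of any AD call (a hypothetical snippet, assuming the definitions and imports in the script above; the sizes are arbitrary):

lpd = VecToPosDef(3, 2)            # n = 2 controls ⇒ indim = 2*(2-1)÷2 + 2 = 3
v = randn(Float32, 3, 5)           # batch of 5 input vectors
M = lpd(v)                         # 2×2×5 batch of L * Lᵀ products
all(isposdef(M[:, :, k]) for k in 1:5)   # expected to be true
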
jlmaccal commented 3 years ago

This version tries to use a single network with Zygote for the inner gradient and ReverseDiff for the outer.

It fails (full traceback below) with: ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting `.=` instead?

using Flux
using Zygote
using ReverseDiff
using NNlib
using LinearAlgebra
using Statistics  # for `mean` in the loss below

export predictpower, create_Fnetwork, create_ξnetwork, combine_networks

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2

    # Unpack the inputs
    net = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]

    # Forward pass
    results, pullback = Zygote.pullback(net, λ)
    F = results.F
    ξ = results.ξ
    ∇F = pullback((F = ones(eltype(F), size(F)), ξ = nothing))[1]  # gradient w.r.t. λ
    ∇F = reshape(∇F, 1, :)

    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])

    # Compute the conservative part of the power
    Pcons = batched_mul(dλ, ∇F)

    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ  (ξ was already extracted from the network output above)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)

    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take an `indim`-dimensional batch of vectors and convert it to
a batch of `(n, n)` positive definite matrices. The dimensions
must match such that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as the elements of a lower triangular
matrix. The diagonal elements are exponentiated to ensure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64

    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim

    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)

    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end

    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end

    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end

# Test it
net = create_network(2, 128, 2)
θ, builder = Flux.destructure(net)

x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = ReverseDiff.gradient(θ -> loss(x, θ, builder), θ)

Here is the traceback:

ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
 [1] setindex_shape_check(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Int64) at ./indices.jl:258
 [2] macro expansion at ./multidimensional.jl:795 [inlined]
 [3] _unsafe_setindex!(::IndexLinear, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Base.Slice{Base.OneTo{Int64}}) at ./multidimensional.jl:789
 [4] _setindex! at ./multidimensional.jl:785 [inlined]
 [5] setindex! at ./abstractarray.jl:1153 [inlined]
 [6] setindex!(::Zygote.Buffer{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Function) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/tools/buffer.jl:51
 [7] adjoint at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/buffer.jl:15 [inlined]
 [8] _pullback at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [9] VecToPosDef at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:100 [inlined]
 [10] _pullback(::Zygote.Context, ::VecToPosDef, ::ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [11] output at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:51 [inlined]
 [12] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined] (repeats 5 times)
 [13] Chain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:38 [inlined]
 [14] _pullback at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:38 [inlined]
 [15] pullback(::Chain{Tuple{Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(identity),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},var"#output#13"{VecToPosDef}}}, ::Array{Float32,2}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:44
 [16] predictpower(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Flux.var"#34#36"{Chain{Tuple{Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},var"#output#13"{VecToPosDef}}}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:20
 [17] loss(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Function) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:124
 [18] (::var"#14#15")(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
 [19] ReverseDiff.GradientTape(::var"#14#15", ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:199
 [20] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/gradients.jl:22 (repeats 2 times)
 [21] top-level scope at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
 [22] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [23] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at ./essentials.jl:710
 [24] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [25] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:83
 [26] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:45
 [27] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:118
 [28] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:43
 [29] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:36
 [30] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:23
 [31] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/JSONRPC/src/typed.jl:66
 [32] macro expansion at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/VSCodeServer.jl:95 [inlined]
 [33] (::VSCodeServer.var"#61#63"{Bool,String})() at ./task.jl:356
in expression starting at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
DhairyaLGandhi commented 3 years ago

@ChrisRackauckas the MWE can actually be reduced to only include the x -> x[1, :] part, since that is what causes the mutation error:


julia> gradient(rand(3,3)) do p
         gradient(p) do p
           sum(p[1, :])
         end[1] |> sum
       end
axsk commented 3 years ago

I'm stuck with the same problem. How did you solve yours?

https://discourse.julialang.org/t/flux-higher-order-derivatives-and-forward-mode/38805/3 seems to have a similar problem, but the workaround there (extracting the NN parameters and optimizing externally) doesn't work in my case, since I want to stay within the Flux training framework.

I tried integrating that approach by writing a Zygote.@adjoint, but could not work out how to untangle the closures.

axsk commented 3 years ago

So what is the actual problem preventing Zygote from computing higher order derivatives?

DhairyaLGandhi commented 3 years ago

So the "issue" is that Zygote uses mutation in the adjoint of getindex. Hmm, let me think about whether we can handle it better.

axsk commented 3 years ago

I actually thought the problem lay elsewhere, but using sum instead of getindex seems to work here:

x = rand(2)
m = Chain(Dense(2,1))

Flux.gradient(params(m)) do
    gradient(m,x) |> sum |> sum
end

Edit: loss(f,x) = sum(abs2, Flux.gradient(x->f(x) |> sum, x) |> sum) indeed works for me as desired (although I should probably comment the use of the sums in the source :D). So relieved; I already thought I had to switch to JAX.

DhairyaLGandhi commented 3 years ago

Yeah, I think that should be fine, but it is less generally correct to do.
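
A hedged reading of why it is less general: seeding the inner differentiation with a summed output computes the gradient of the sum of all output components, which coincides with the gradient of one particular component only when the other components do not contribute:

\nabla_x \sum_i f_i(x) \;=\; \sum_i \nabla_x f_i(x) \;\neq\; \nabla_x f_j(x) \quad \text{in general}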

axsk commented 3 years ago

A cleaner way is to extract the gradient by tuple destructuring (is it called that?): dx, = gradient(m, x)
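
For example, a minimal sketch combining the workaround above with that destructuring (m and x as in the earlier snippet):

dx, = Flux.gradient(x -> sum(m(x)), x)   # inner gradient of the summed output w.r.t. the input
sum(abs2, dx)                            # outer loss built from that gradient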

ChrisRackauckas commented 3 years ago

The issue is that:

∇getindex(x::AbstractArray, inds) = dy -> begin
  if inds isa NTuple{<:Any, Integer}
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
  else
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
  end
  return (dx, map(_->nothing, inds)...)
end

These mutate. I would suggest splitting that into two separate dispatches and trying to come up with schemes that are just broadcasts or filters. If that's not easy to do, then I think a dispatch on just arrays (to avoid CuArrays) that just loops would be nice and fix the problem for most non-GPU users.
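
For example, one non-mutating shape for the integer-index branch might look roughly like this (hypothetical helper, just a sketch of the "pure broadcast/map" idea, ignoring the GPU and performance concerns above; not Zygote's actual implementation):

function _onehot_grad(x::AbstractArray, inds::NTuple{<:Any, Integer}, dy)
    target = CartesianIndex(inds...)
    # non-mutating replacement for `dx = _zero(x, typeof(dy)); dx[inds...] = dy`
    return map(I -> I == target ? dy : zero(dy), CartesianIndices(x))
end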

Janssena commented 3 years ago

I believe I'm facing a similar issue, where I need to use the jacobian of my prediction function with respect to an array of random variables in the loss function.

Here is my code:

# produces θ for pred function
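# (`diffeq` and `des` below are module aliases from imports not shown here, likely
#  DifferentialEquations and DiffEqSensitivity; `input` is the network's input dimension.)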
ann = Flux.Chain(
    Flux.Dense(input, 32, Flux.tanh),
    Flux.Dense(32, 32),
    Flux.Dense(32, 3),
);

function dAdt(dA, A, p, t)
    a, b, c = p
    dA[1] = -a * b * A[1] 
    dA[2] = c * b * A[1] - a * A[2]
end

function pred(θ, η, t, callback)
    p = θ .* exp.(η)
    prob = diffeq.ODEProblem(dAdt, [0., 0.], (-.1, maximum(t)), p)
    sol = diffeq.solve(prob, diffeq.Tsit5(), saveat=t, tstops=[0.], callback=callback, sensealg=des.ForwardDiffSensitivity())
    return sol[2 , :] # A[2] corresponds to y measurements
end 

∂pred_∂η(θ, η, time, callback) = Zygote.jacobian(eta -> pred(θ, eta, time, callback), η)

# p == 3x3 correlation matrix
function Obj(x, y, p, times, callbacks)
    if !isposdef(p)
        return Inf
    end

    N = length(times) # equal to the number of observations in dataset
    θ = ann(x')
    η = zeros(size(p, 1)) # test

    loss = 0.

    for i in 1:N
        ŷ = pred(θ[:, i], η, times[i], callbacks[i])
        residuals = y[i] - ŷ
        jac_eta = ∂pred_∂η(θ[:, i], η, times[i], callbacks[i]) # line 1
        loss = mean(residuals) + mean(jac_eta * p * jac_eta') # line 2
    end

    return loss
end

grad = Zygote.gradient(() -> Obj(x, y, p, times, callbacks), Flux.params(ann)) # error mutating arrays

Removing line 1 and changing line 2 to loss = mean(residuals) runs fine, but calculating the Jacobian results in the mutating-arrays error in Zygote. Is anyone working on implementing Chris's comment above, or is there some way I can help with this? I'm not that experienced with working on Zygote code, but I'm trying to solve the above issue.

ChrisRackauckas commented 3 years ago

https://github.com/FluxML/Zygote.jl/pull/77 is a solution that could be used.

axsk commented 3 years ago

So out of #77 we would just need the @adjoint ∇getindex part to circumvent the setindex! call, is that correct?

ChrisRackauckas commented 3 years ago

Yes

axsk commented 3 years ago

This is my take on Keno's approach:


∇getindex(x::AbstractArray, inds) = dy -> (_zerosetindex(x, inds, dy), map(_->nothing, inds)...)

function _zerosetindex(x, inds::NTuple{<:Any, Integer}, dy)
  dx = _zero(x, typeof(dy))
  dx[inds...] = dy
  dx
end

function _zerosetindex(x, inds, dy)
  dx = _zero(x, eltype(dy))
  dxv = view(dx, inds...)
  dxv .= accum.(dxv, _droplike(dy, dxv))
  dx
end

@adjoint function _zerosetindex(x, inds, dy)
  _zerosetindex(x, inds, dy), ddx -> (nothing, nothing, ddx[inds...])
end

Keno's tests seem to run through as well. Should I put up a PR?

ChrisRackauckas commented 3 years ago

I think that would be great!

DhairyaLGandhi commented 3 years ago

We'll want to test this with GPUs and check the performance.