Your example can be reduced even more:
using Flux
using Zygote
net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> x[1, :]
)
θ, builder = Flux.destructure(net)
x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    A, pullback = Zygote.pullback(net, x)
    ∇A = pullback(ones(eltype(A), size(A)))[1]
    return sum(∇A)
end
Zygote.gradient(θ -> predict(θ, x, builder), θ)
The issue is just that you can't nest Zygote. If you make the outer differentiation use ReverseDiff.jl and the inner one use Zygote.jl you're fine though.
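For the reduced example above, that mixed-mode setup would look roughly like this. This is only a sketch that mirrors the `predict` from the snippet above; whether ReverseDiff and Zygote compose cleanly here depends on the package versions involved.

using Flux
using Zygote
using ReverseDiff

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 4),
    x -> x[1, :]
)
θ, builder = Flux.destructure(net)
x = randn(Float32, 2, 16)

# Inner reverse pass with Zygote: gradient of the network output w.r.t. x
function predict(θ, x, builder)
    net = builder(θ)
    A, back = Zygote.pullback(net, x)
    ∇A = back(ones(Float32, size(A)))[1]
    return sum(∇A)
end

# Outer differentiation with ReverseDiff instead of Zygote
ReverseDiff.gradient(p -> predict(p, x, builder), θ)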
@ChrisRackauckas Thanks for the reply. I can see now that many of the pullbacks defined for common operations, like `getindex`, mutate in place, which prevents Zygote from being nested.
Do you know if there is any plan to address this? It seems to me that loss functions that use the gradients of a scalar field could be quite common when modelling physical systems with conserved quantities, as with Hamiltonian NNs.
I have tried combining various ADs as you suggest, but keep running into problems. I give some examples of things I've tried below. My apologies for their length, but I feel it might help to show what I'm actually trying to accomplish.
For context, I am trying to model the power dissipated by a non-equilibrium thermodynamic system undergoing some control protocol. This is modelled with two components. The first is a conservative term that depends on the directional derivative of a scalar field (the free energy). The second is a dissipative term that depends on a positive definite friction tensor. In the spirit of SciML, I'm trying to model both the free energy and the friction tensor using NNs.
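In symbols, the power being modelled is roughly

P(λ, dλ) = dλ ⋅ ∇F(λ) + dλᵀ ⋅ ξ(λ) ⋅ dλ

where F is the free-energy network, ξ(λ) is the positive definite friction tensor, and the two terms correspond to `Pcons` and `Pdiss` in the code below.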
Here is one approach I tried, using ForwardDiff to calculate the inner directional derivative and Zygote for the outer gradient. If I don't include DiffEqFlux, I get an error about no method matching `*(::NamedTuple...`, which is addressed by one of the adjoint definitions in DiffEqFlux.
However, this does not give me the correct gradient. Instead, all of the gradients for the ξnet parameters are zero. If I change `return Pcons + Pdiss` to `return Pdiss`, then I get the correct (or at least non-zero) gradient for the ξnet parameters, but the Fnet gradients are zero.
using Flux
using Zygote
using DiffEqFlux
using ForwardDiff
using NNlib
using LinearAlgebra
using Statistics
export predictpower, create_Fnetwork, create_ξnetwork, combine_networks
struct Builder{R1,R2}
    re1::R1
    re2::R2
    n::Int64
end

(builder::Builder)(p) = begin
    p1 = p[1:builder.n]
    p2 = p[(builder.n + 1):end]
    return (builder.re1(p1), builder.re2(p2))
end

function combine_models(m1, m2)
    p1, re1 = Flux.destructure(m1)
    p2, re2 = Flux.destructure(m2)
    n = size(p1)[1]
    p = [p1; p2]
    builder = Builder(re1, re2, n)
    return (p, builder)
end
struct DirectionalDerivative{F, V}
    f::F
    direction::V
end

const DD = DirectionalDerivative

function (dd::DD)(pt)
    let dd = dd
        ForwardDiff.derivative(0) do h
            dd.f(pt + h * dd.direction)
        end
    end
end
function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2
    # Unpack the inputs
    Fnet, ξnet = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]
    # Compute the directional derivative dλ⋅∇F
    Pcons = DD(Fnet, dλ)(λ)
    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])
    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    ξ = ξnet(λ)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)
    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end
function create_Fnetwork(controldim, hiddendim, hiddenlayers)
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, 1)
    return Chain(initial, layers..., final)
end

function create_ξnetwork(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, componentdim)
    posdef = VecToPosDef(componentdim, controldim)
    return Chain(initial, layers..., final, posdef)
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end
"""
VecToPosDef(indim, n)
Convert a vector to a positive definite matrix.
Take `indims` dimensional batch of vectors and convert to
a batch of `(n, n)`` positive definite matrices. The dimensions
must much sch that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of lower triangular
matrix. The diagonal elements are exponentated to unsure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64
    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim
    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)
    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end
    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end
    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end
# Test code
Fnet = create_Fnetwork(2, 128, 2)
ξnet = create_ξnetwork(2, 128, 2)
θ, builder = combine_models(Fnet, ξnet)
x = randn(Float32, 4, 128)
function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end
grad = Zygote.gradient(p -> loss(x, p, builder), θ)[1]
grad = getindex.(ForwardDiff.partials.(grad),1)
This version tries to use a single network, with Zygote for the inner gradient and ReverseDiff for the outer. It fails with (full traceback below): `ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting .= instead?`
using Flux
using Zygote
using ReverseDiff
using NNlib
using LinearAlgebra
using Statistics
export predictpower, create_Fnetwork, create_ξnetwork, combine_networks
function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2
    # Unpack the inputs
    net = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]
    # Forward pass
    results, pullback = Zygote.pullback(net, λ)
    F = results.F
    ξ = results.ξ
    # Pull back a unit cotangent on F only to get ∇F w.r.t. λ
    ∇F = pullback((F = ones(eltype(F), size(F)), ξ = nothing))[1]
    ∇F = reshape(∇F, 1, :)
    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])
    # Compute the conservative part of the power
    Pcons = batched_mul(dλ, ∇F)
    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)
    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end
function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        # Return a NamedTuple so predictpower can pull back on the F field by name
        return (F = F, ξ = ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end
"""
VecToPosDef(indim, n)
Convert a vector to a positive definite matrix.
Take `indims` dimensional batch of vectors and convert to
a batch of `(n, n)`` positive definite matrices. The dimensions
must much sch that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of lower triangular
matrix. The diagonal elements are exponentated to unsure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64
    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim
    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)
    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end
    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end
    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end
# Test it
net = create_network(2, 128, 2)
θ, builder = Flux.destructure(net)
x = randn(Float32, 4, 128)
function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end
grad = ReverseDiff.gradient(θ -> loss(x, θ, builder), θ)
Here is the traceback:
ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
[1] setindex_shape_check(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Int64) at ./indices.jl:258
[2] macro expansion at ./multidimensional.jl:795 [inlined]
[3] _unsafe_setindex!(::IndexLinear, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Base.Slice{Base.OneTo{Int64}}) at ./multidimensional.jl:789
[4] _setindex! at ./multidimensional.jl:785 [inlined]
[5] setindex! at ./abstractarray.jl:1153 [inlined]
[6] setindex!(::Zygote.Buffer{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Function) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/tools/buffer.jl:51
[7] adjoint at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/buffer.jl:15 [inlined]
[8] _pullback at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[9] VecToPosDef at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:100 [inlined]
[10] _pullback(::Zygote.Context, ::VecToPosDef, ::ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[11] output at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:51 [inlined]
[12] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined] (repeats 5 times)
[13] Chain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:38 [inlined]
[14] _pullback at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:38 [inlined]
[15] pullback(::Chain{Tuple{Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(identity),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},var"#output#13"{VecToPosDef}}}, ::Array{Float32,2}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:44
[16] predictpower(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Flux.var"#34#36"{Chain{Tuple{Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},var"#output#13"{VecToPosDef}}}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:20
[17] loss(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Function) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:124
[18] (::var"#14#15")(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
[19] ReverseDiff.GradientTape(::var"#14#15", ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:199
[20] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/gradients.jl:22 (repeats 2 times)
[21] top-level scope at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
[22] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
[23] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at ./essentials.jl:710
[24] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
[25] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:83
[26] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:45
[27] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:118
[28] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:43
[29] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:36
[30] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:23
[31] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/JSONRPC/src/typed.jl:66
[32] macro expansion at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/VSCodeServer.jl:95 [inlined]
[33] (::VSCodeServer.var"#61#63"{Bool,String})() at ./task.jl:356
in expression starting at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
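One way to avoid the `Zygote.Buffer` writes that this traceback points at is to build the lower-triangular factor with a constant scatter matrix instead of indexed assignment. The sketch below is illustrative only (the helper names `lowertri_scatter` and `vec_to_posdef` are made up, and removing the Buffer does not by itself make nested Zygote work); it reproduces the same L·Lᵀ construction as `VecToPosDef` above without mutating inside the differentiated code.

using NNlib: batched_mul

# Constant (n*n) × indim matrix that scatters each input row into its slot
# of a lower-triangular n×n matrix, in the same order VecToPosDef uses:
# first the n diagonal entries, then the strict lower triangle row by row.
function lowertri_scatter(n)
    indim = n * (n - 1) ÷ 2 + n
    S = zeros(Float32, n * n, indim)
    for j in 1:n
        S[j + (j - 1) * n, j] = 1          # diagonal entries
    end
    i = n
    for j in 1:n, k in 1:(j - 1)
        i += 1
        S[j + (k - 1) * n, i] = 1          # strict lower triangle
    end
    return S
end

# Mutation-free alternative to VecToPosDef: exponentiate the diagonal rows,
# scatter into an (n, n, nbatch) array, and form L·Lᵀ batch-wise.
function vec_to_posdef(x, S, n)
    nbatch = size(x, 2)
    xt = vcat(exp.(x[1:n, :]), x[n + 1:end, :])
    L = reshape(S * xt, n, n, nbatch)
    return batched_mul(L, permutedims(L, (2, 1, 3)))
end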
@ChrisRackauckas the MWE can actually be reduced to only include the `x -> x[1, :]`, since that is what causes the mutation error:
julia> gradient(rand(3,3)) do p
           gradient(p) do p
               sum(p[1, :])
           end[1] |> sum
       end
I'm stuck with the same problem; how did you solve yours?
https://discourse.julialang.org/t/flux-higher-order-derivatives-and-forward-mode/38805/3 seems to have a similar problem, but the workaround there (extracting the NN parameters and optimizing externally) doesn't work in my case since I want to stay within the Flux training framework.
I tried integrating that approach to write a `Zygote.@adjoint` but could not work out how to mangle the closures.
So what is the actual problem preventing Zygote from computing higher order derivatives?
So the "issue" is that zygote uses mutation on the adjoint of the getindex. Hmm, let me think about if we can handle it better
I actually thought the problem lay elsewhere, but using `sum` instead of `getindex` seems to work here.
x = rand(2)
m = Chain(Dense(2, 1))
Flux.gradient(params(m)) do
    gradient(m, x) |> sum |> sum
end
Edit:
loss(f,x) = sum(abs2, Flux.gradient(x->f(x) |> sum, x) |> sum)
indeed works for me as desired (although I should probably comment the use of the sums in the source :D )
So relieved, I already thought I had to switch to JAX
Yeah, I think that should be fine but it is less generally correct to do, I think
A cleaner way is to extract the gradient by tuple destructuring (is it called that?)
dx, = gradient(m, x)
The issue is that:
∇getindex(x::AbstractArray, inds) = dy -> begin
    if inds isa NTuple{<:Any, Integer}
        dx = _zero(x, typeof(dy))
        dx[inds...] = dy
    else
        dx = _zero(x, eltype(dy))
        dxv = view(dx, inds...)
        dxv .= accum.(dxv, _droplike(dy, dxv))
    end
    return (dx, map(_->nothing, inds)...)
end
These mutate. I would suggest splitting that into two separate dispatches and trying to come up with schemes that are just broadcasts or filters. If that's not easy to do, then I think a dispatch on just arrays (to avoid CuArrays) that just loops would be nice and fix the problem for most non-GPU users.
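For the integer-index branch, one mutation-free scheme along those lines could look like the following. This is just an illustrative sketch (the helper name is made up, and it only covers the case of one integer index per dimension), not Zygote's actual code:

# Build the cotangent of getindex(x, inds...) for all-integer indices by
# mapping over the array's indices instead of writing into a zero array.
nomut_getindex_adjoint(x::AbstractArray{<:Number,N}, inds::Vararg{Integer,N}) where {N} = dy ->
    map(I -> Tuple(I) == inds ? dy : zero(dy), CartesianIndices(x))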
I believe I'm facing a similar issue, where I need to use the jacobian of my prediction function with respect to an array of random variables in the loss function.
Here is my code:
# produces θ for pred function
ann = Flux.Chain(
    Flux.Dense(input, 32, Flux.tanh),
    Flux.Dense(32, 32),
    Flux.Dense(32, 3),
);

function dAdt(dA, A, p, t)
    a, b, c = p
    dA[1] = -a * b * A[1]
    dA[2] = c * b * A[1] - a * A[2]
end

function pred(θ, η, t, callback)
    p = θ .* exp.(η)
    prob = diffeq.ODEProblem(dAdt, [0., 0.], (-.1, maximum(t)), p)
    sol = diffeq.solve(prob, diffeq.Tsit5(), saveat=t, tstops=[0.], callback=callback, sensealg=des.ForwardDiffSensitivity())
    return sol[2, :] # A[2] corresponds to y measurements
end

∂pred_∂η(θ, η, time, callback) = Zygote.jacobian(eta -> pred(θ, eta, time, callback), η)

# p == 3x3 correlation matrix
function Obj(x, y, p, times, callbacks)
    if !isposdef(p)
        return Inf
    end
    N = length(times) # equal to the number of observations in dataset
    θ = ann(x')
    η = zeros(size(p, 1)) # test
    loss = 0.
    for i in 1:N
        ŷ = pred(θ[:, i], η, times[i], callbacks[i])
        residuals = y[i] - ŷ
        jac_eta = ∂pred_∂η(θ[:, i], η, times[i], callbacks[i]) # line 1
        loss = mean(residuals) + mean(jac_eta * p * jac_eta') # line 2
    end
    return loss
end

grad = Zygote.gradient(() -> Obj(x, y, p, times, callbacks), Flux.params(ann)) # error mutating arrays
Removing line 1 and changing line 2 to `loss = mean(residuals)` runs fine, but calculating the Jacobian results in the mutating-arrays error in Zygote.
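Following the mixed-AD suggestion from earlier in the thread, one thing that might be worth trying is computing the inner Jacobian with ForwardDiff instead of Zygote, so the outer Zygote pass never has to differentiate through Zygote's own `getindex` adjoint. This is only a sketch (the name `∂pred_∂η_fd` is made up, and whether the outer gradient then flows through it depends on the adjoints that are loaded, e.g. via DiffEqFlux):

using ForwardDiff

# Same role as ∂pred_∂η above, but with a forward-mode inner Jacobian
∂pred_∂η_fd(θ, η, time, callback) =
    ForwardDiff.jacobian(eta -> pred(θ, eta, time, callback), η)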
Is anyone working on implementing the above comment by Chris, or is there some way I can help with this? I'm not that experienced with working on Zygote code, but I am trying to solve the above issue.
https://github.com/FluxML/Zygote.jl/pull/77 is a solution that could be used.
So out of #77 we would just need the `@adjoint ∇getindex` part to circumvent the `setindex` call, is that correct?
Yes
This is my take on Keno's approach:
∇getindex(x::AbstractArray, inds) = dy -> (_zerosetindex(x, inds, dy), map(_->nothing, inds)...)

function _zerosetindex(x, inds::NTuple{<:Any, Integer}, dy)
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
    dx
end

function _zerosetindex(x, inds, dy)
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
    dx
end

@adjoint function _zerosetindex(x, inds, dy)
    _zerosetindex(x, inds, dy), ddx -> (nothing, nothing, ddx[inds...])
end
Keno's tests seem to run through as well. Should I put up a PR?
I think that would be great!
We'll want to test this with GPUs, and check for performance
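For a quick local check (not the package's actual test suite), the nested `getindex` example from earlier in the thread is the natural smoke test; it errors with the mutating adjoint and should go through once a non-mutating version like `_zerosetindex` is in place:

using Zygote

# Nested reverse-over-reverse through getindex
Zygote.gradient(rand(3, 3)) do p
    sum(Zygote.gradient(q -> sum(q[1, :]), p)[1])
end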
I'm new to Flux/Zygote/Julia. I'm trying to develop a model that looks something like the one below. I have a network that produces two outputs, `A` and `B`. The gradient of `A` with respect to the inputs is part of my loss function, along with other terms that depend on `B`. I've just summed things here for simplicity, but my actual model produces the same error.
The error is `ERROR: LoadError: Mutating arrays is not supported`, which comes from the `x -> (A=x[1, :], B=x[2:end, :])` line in the network, but I don't understand where the mutation is coming from.
I gather from a number of issues here and threads on Discourse that higher-order gradients are not well supported, but there isn't much documentation around this. As a new user, it would be extremely helpful if there were some kind of documentation / guidance about how to work around this.
On a related Discourse thread @ChrisRackauckas suggested using another AD, like ReverseDiff, but I can't figure out how to get the gradient that I want. Any guidance would be appreciated.
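A minimal sketch of the kind of model and loss described above (reconstructed from the description, not the poster's exact code; the layer sizes are made up, and the `x -> (A = ..., B = ...)` output layer is the line named in the error):

using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 4),
    x -> (A = x[1, :], B = x[2:end, :]),   # two named outputs, as described
)
x = randn(Float32, 2, 16)

function loss(net, x)
    A, back = Zygote.pullback(z -> net(z).A, x)
    ∇A = back(ones(Float32, size(A)))[1]   # gradient of A w.r.t. the inputs
    return sum(∇A) + sum(net(x).B)         # everything just summed for simplicity
end

# Differentiating this w.r.t. the parameters nests Zygote and triggers
# "Mutating arrays is not supported" via the getindex adjoint.
Zygote.gradient(() -> loss(net, x), Flux.params(net))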