Open stelmo opened 3 years ago
Thanks for bringing this up! If you remove the need to use `params`, you get a more meaningful error.
```julia
function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function
    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) # problem source (3rd input is time)
    return fitloss + derivativeloss
end
```
```julia
gs = gradient(m, xt, yt) do m, x, y
    loss(m, x, y)
end
```

```julia
julia> gs = gradient(m, xt, yt) do m, x, y
           loss(m, x, y)
       end # this generates a foreigncall exception
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#380#381")(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:58
 [3] (::Zygote.var"#2288#back#382"{Zygote.var"#380#381"})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] (::Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}})(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/lib.jl:191
 [5] (::Zygote.var"#1699#back#154"{Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}}})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [6] #372 at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:38 [inlined]
```
The same can be seen using https://github.com/FluxML/Zygote.jl/pull/823 directly as well.
This comes from picking out the first element of the model output. Instead, if you use something like
```julia
julia> function loss(m, x, y)
           fitloss = Flux.Losses.mse(m(x), y) # typical loss function
           derivativeloss = abs2(gradient(a -> sum(m(a)), x)[1][3]) # problem source (3rd input is time)
           return fitloss + derivativeloss
       end
```
(notice the call to `sum` instead of taking the first element), then things work fine.

In cases where we do need the element-wise grads, `sum` doesn't quite cut it, and we need to write the `getindex` adjoint so that it does not mutate an array (which it currently does for performance). We could always write a fallback rule that handles it as well.
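One possible shape for such a fallback, sketched here only to illustrate the idea (the helper name is invented, and the `ChainRulesCore.rrule` wiring is hypothetical and left in comments): build the one-hot cotangent functionally instead of allocating a zeros array and writing into it.

```julia
# Sketch of a non-mutating pullback for scalar getindex on a vector.
# Instead of allocating zeros and calling setindex! on them (the
# mutation that nested Zygote differentiation trips over), build the
# one-hot cotangent with a comprehension.
onehot_cotangent(x::AbstractVector, i::Integer, dy) =
    [j == i ? dy : zero(dy) for j in eachindex(x)]

# Hypothetical wiring via ChainRulesCore (assumed, not tested here):
# function ChainRulesCore.rrule(::typeof(getindex), x::AbstractVector, i::Integer)
#     x[i], dy -> (NoTangent(), onehot_cotangent(x, i, dy), NoTangent())
# end
```

This allocates a fresh array per pullback call, which is exactly the performance cost the current mutating implementation avoids.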
Thank you very much! This fixes the one-dimensional case :) I will add this thread to my Discourse question (as well as the other question on there).

How would one write the adjoint for the `getindex` function? I will need it to extend the PINN system to multiple dimensions... I looked at the source code for `Base.getindex` and it is not very clear to me what's going on; hopefully I won't have to change anything there? The Zygote documentation mentions that one should use ChainRules to define custom adjoint rules. Should I use that instead to define a fallback rule?
What would the multidimensional case entail?
You're better off checking out the code for the `getindex` adjoint in Zygote, in src/lib/array.jl, and writing your own adjoint based on it.
Essentially I mean element-wise gradients, e.g.
```julia
m = Chain(Dense(5, 10, relu), Dense(10, 10, relu), Dense(10, 2)) # [u0_1, u0_2, k1, k2, t] -> [u1(t), u2(t)]

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function
    derivativeloss = 0.0f0
    for i = 1:size(x, 2)
        for j = 1:2 # dimension to take gradient of
            derivativeloss += abs2(gradient(a -> m(a)[j], x[:, i])[1][5]) # ||du_j/dt|| for j=1,2; this mutates again :(
        end
    end
    return fitloss + derivativeloss
end

xt = rand(Float32, 5, 10)
yt = rand(Float32, 2, 10)

gs = gradient(m, xt, yt) do m, x, y
    loss(m, x, y)
end
```
Okay, I've looked in Zygote/src/lib/array.jl and found this:
```julia
@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
    if inds isa NTuple{<:Any, Integer}
        dx = _zero(x, typeof(dy))
        dx[inds...] = dy
    else
        dx = _zero(x, eltype(dy))
        dxv = view(dx, inds...)
        dxv .= accum.(dxv, _droplike(dy, dxv))
    end
    return (dx, map(_ -> nothing, inds)...)
end
```
Is the mutation happening here:
```julia
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
```
Would I have to change that somehow (I am not sure at all what to do)? I really appreciate your help in this!
Right, anywhere `setindex!` gets in the way.
Bump.
Zygote.jl now has a `jacobian` function, but it results in the same mutating-array error. I think this is very critical for physics-informed machine learning in general; it would be great to have some clarity on this! @DhairyaLGandhi I'm not quite sure I get your last comment. Thanks!
This is totally not my wheelhouse, but there are a few threads on Discourse about representing PINNs that you could look into.
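For what it's worth, a workaround often suggested for nested derivatives like this is to take the inner (physics) derivative with ForwardDiff and leave Zygote only for the outer training gradient, since the inner forward-mode pass does not go through Zygote's mutating `getindex` adjoint. A minimal sketch, assuming ForwardDiff is installed; `dudt` is an illustrative name:

```julia
using ForwardDiff

# Derivative of the j-th output of f with respect to the 5th input
# (time, in the model layout above), via forward-mode AD.
dudt(f, x, j) = ForwardDiff.gradient(a -> f(a)[j], x)[5]

# Example with a plain function standing in for the network:
f(a) = [a[1]^2, a[5]^2]
dudt(f, ones(5), 2)  # derivative of a[5]^2 w.r.t. a[5] at 1.0 → 2.0
```

Whether Zygote can differentiate through this inner ForwardDiff call for a given model still needs to be checked case by case.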
This code isn't correct, right? I'm getting an error at the gradient part.
Hi, I am trying to implement a PINN as described here using Flux. Essentially, I am trying to train a neural network whose loss function includes the network's time derivative (time is one of its inputs). Below is a very minimal example:
This issue seems to be pervasive, see here and #1338 and #1257 and here (the last one is me on the Discourse channel). I have tried all the suggestions in the aforementioned links, but nothing seems to work. Do you have a workaround, or is this some built-in limitation of Flux/Zygote?