mcabbott opened this issue 4 years ago
gradient(x -> sum(setindex!!(x, 2, 2)), zeros(3)) would not work with current BangBang, since for a mutable array like zeros(3) it just calls the standard setindex!, and Zygote cannot differentiate through array mutation.
If you only need to differentiate setindex on vectors and matrices, a straightforward approach may be to use Base.setindex after importing ArrayInterface.jl (which pirates Base.setindex, i.e. adds methods to a Base function for Base's own array types).
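For intuition, here is a minimal sketch of what a non-mutating setindex on a vector amounts to. The name setindex_copy is hypothetical: it is not part of BangBang or ArrayInterface, and this is not how ArrayInterface implements Base.setindex. The only point is that it builds a new vector instead of writing into x with setindex!; whether Zygote differentiates this particular comprehension efficiently is not checked here.

```julia
# Hypothetical helper, not part of BangBang or ArrayInterface:
# returns a copy of `x` with position `i` replaced by `v`, leaving `x` untouched.
function setindex_copy(x::AbstractVector, v, i::Integer)
    checkbounds(x, i)              # throws the same BoundsError that x[i] = v would
    vT = convert(eltype(x), v)     # e.g. store the Int 2 as 2.0 in a Float64 vector
    # Build a fresh vector instead of calling setindex!(x, v, i).
    return [j == i ? vT : x[j] for j in eachindex(x)]
end

x = zeros(3)
y = setindex_copy(x, 2, 2)
println(y)   # [0.0, 2.0, 0.0]
println(x)   # [0.0, 0.0, 0.0]  (original is unchanged)
```

The trade-off is one allocation per call. Avoiding that allocation is exactly why BangBang's !! functions mutate whenever the container allows it, and that mutation is what Zygote objects to here.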
Re: gradient(x -> sum(setindex!!(x, 2, 2)), SA[0,0,0]), it seems that you can get the same error without BangBang:
julia> using StaticArrays, Zygote
julia> gradient(x -> sum(Base.setindex(x, 2, 2)), SA[0,0,0])
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Any, (7,)}[
Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._setindex), StaticArrays.Length{3}, StaticArrays.SArray{Tuple{3}, Int64, 1, 3}, Int64, Int64}, Any}, undef=false),
Core.Compiler.VarState(typ=FillArrays.Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}, undef=false),
Core.Compiler.VarState(typ=Core.Compiler.Const(val=nothing, actual=false), undef=true),
Core.Compiler.VarState(typ=Any, undef=false),
Core.Compiler.VarState(typ=Any, undef=false),
Core.Compiler.VarState(typ=Any, undef=false),
Core.Compiler.VarState(typ=Any, undef=false)], i=(8,))
rec_backtrace at /buildworker/worker/package_linux64/build/src/stackwalk.c:94
Re: "gradient(x -> sum(setindex!!(x, 2, 2)), zeros(3)) would not work with current BangBang since it just uses standard setindex!."
Thanks for having a look! I guess that was roughly my question: whether setindex!! could automatically hide the mutation when it is called by Zygote.
Re: Base.setindex, indeed this seems to work, just not on StaticArrays:
julia> Zygote.gradient((t,v) -> sum(Base.setindex(t, v, 2)), (0,0,0), 99)
((1, 0, 1), 1)
julia> using ArrayInterface
julia> Zygote.gradient((t,v) -> sum(Base.setindex(t, v, 2)), [0,0,0], 99)
([1, 0, 1], 1)
Are these meant to work?
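The gradients printed above, ((1, 0, 1), 1) and ([1, 0, 1], 1), are what one gets by working out the pullback of y = setindex(x, v, i) by hand: every entry of y except y[i] is copied from x, and y[i] is v. So the cotangent for x is the incoming cotangent with entry i zeroed, and the cotangent for v is its i-th entry. Below is a minimal sketch in plain Julia; setindex_pullback is a hypothetical name. In practice such a rule would typically be registered for Zygote through ChainRulesCore.rrule, which is omitted here so the example runs without packages.

```julia
# Hypothetical hand-written pullback for y = setindex(x, v, i).
# Given the cotangent dy of the output, returns (dx, dv):
#   dx = dy with entry i set to zero (x[i] was overwritten, so it has no effect on y)
#   dv = dy[i]                       (v only appears at position i)
function setindex_pullback(dy::AbstractVector, i::Integer)
    dx = collect(dy)              # fresh copy, so the caller's dy is never mutated
    dv = dx[i]
    dx[i] = zero(eltype(dx))
    return dx, dv
end

# For f(t, v) = sum(setindex(t, v, 2)) with length-3 t, the cotangent of the sum is all ones.
dx, dv = setindex_pullback(ones(Int, 3), 2)
println((dx, dv))   # ([1, 0, 1], 1)
```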