JuliaFolds / BangBang.jl

Immutables as mutables, mutables as immutables.
MIT License

Errors when used with Zygote #58

Open mcabbott opened 4 years ago

mcabbott commented 4 years ago

Are these meant to work?

julia> using StaticArrays, Zygote, BangBang

julia> setindex!!(SA[0,0,0], 2, 2) # ok!
3-element SArray{Tuple{3},Int64,1,3} with indices SOneTo(3):
 0
 2
 0

julia> gradient(x -> sum(setindex!!(x, 2, 2)), zeros(3))
ERROR: Mutating arrays is not supported

julia> gradient(x -> sum(setindex!!(x, 2, 2)), SA[0,0,0]) # sometimes crashes Julia
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Any, (7,)}[
  Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._setindex), StaticArrays.Length{3}, StaticArrays.SArray{Tuple{3}, Int64, 1, 3}, Int64, Int64}, Any}, undef=false), 
…

(v1.3) pkg> st BangBang
    Status `~/.julia/environments/v1.3/Project.toml`
  [198e06fe] BangBang v0.3.6

(v1.3) pkg> st Zygote
    Status `~/.julia/environments/v1.3/Project.toml`
  [7a1cc6ca] FFTW v1.1.0
  [1a297f60] FillArrays v0.8.2
  [f6369f11] ForwardDiff v0.10.7
  [1914dd2f] MacroTools v0.5.3
  [e88e6eb3] Zygote v0.4.1
  [9a3f8284] Random

tkf commented 4 years ago

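gradient(x -> sum(setindex!!(x, 2, 2)), zeros(3)) would not work with the current BangBang, since for a mutable array like this it just calls the standard setindex!, and that in-place write is what Zygote refuses to differentiate.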
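To illustrate why Zygote complains, here is a minimal sketch of the "mutate if possible, otherwise return a new value" dispatch that !! functions follow. The name my_setindex!! is hypothetical, and this is a simplified imitation rather than BangBang's actual implementation (which covers many more cases): on an Array it ends up calling setindex!, which is exactly the mutation Zygote rejects.

```julia
# Hypothetical, simplified imitation of a BangBang-style `!!` function
# (not BangBang's real code).

# Mutable array: update in place with the standard setindex! and return the
# same object. This in-place write is what Zygote refuses to differentiate.
my_setindex!!(x::Array, v, i::Integer) = (x[i] = v; x)

# Immutable tuple: it cannot be mutated, so build a new tuple instead.
my_setindex!!(x::Tuple, v, i::Integer) = Base.setindex(x, v, i)

a = [0, 0, 0]
b = my_setindex!!(a, 2, 2)
println(b === a)                         # true: same object, mutated in place
println(my_setindex!!((0, 0, 0), 2, 2))  # (0, 2, 0): a new tuple
```

The tuple branch is the kind of code Zygote can follow, as the tuple example later in the thread shows; the Array branch is not.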

If you only need to differentiate through setindex on vectors and matrices, a straightforward approach may be to use Base.setindex after importing ArrayInterface.jl (which adds array methods to Base.setindex through type piracy).
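A different workaround, not the one suggested above, is to avoid setindex! entirely and build the updated vector from broadcasting, which reverse-mode AD tools such as Zygote generally handle. The helper name setindex_nomutate is hypothetical, and whether a particular Zygote version differentiates it cleanly is an assumption that was not tested in this thread:

```julia
# Hypothetical helper (not part of BangBang, ArrayInterface or Zygote):
# returns a new vector equal to `x` except that entry `i` is `v`.
# It never calls setindex!, only broadcasting, so no array is mutated.
function setindex_nomutate(x::AbstractVector, v, i::Integer)
    mask = eachindex(x) .== i            # true only at position i
    return x .* (.!mask) .+ v .* mask    # keep x elsewhere, put v at i
end

x = zeros(3)
y = setindex_nomutate(x, 2, 2)
println(y)   # [0.0, 2.0, 0.0]
println(x)   # [0.0, 0.0, 0.0]: the input is untouched
```

Because x enters only through x .* (.!mask), the derivative of sum(y) with respect to x is 1 everywhere except at position i, where it is 0.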

Re: gradient(x -> sum(setindex!!(x, 2, 2)), SA[0,0,0]), it seems that you can get the same error without BangBang:

julia> using StaticArrays, Zygote

julia> gradient(x -> sum(Base.setindex(x, 2, 2)), SA[0,0,0])
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Any, (7,)}[
  Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._setindex), StaticArrays.Length{3}, StaticArrays.SArray{Tuple{3}, Int64, 1, 3}, Int64, Int64}, Any}, undef=false),
  Core.Compiler.VarState(typ=FillArrays.Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}, undef=false),
  Core.Compiler.VarState(typ=Core.Compiler.Const(val=nothing, actual=false), undef=true),
  Core.Compiler.VarState(typ=Any, undef=false),
  Core.Compiler.VarState(typ=Any, undef=false),
  Core.Compiler.VarState(typ=Any, undef=false),
  Core.Compiler.VarState(typ=Any, undef=false)], i=(8,))
rec_backtrace at /buildworker/worker/package_linux64/build/src/stackwalk.c:94

mcabbott commented 4 years ago

> gradient(x -> sum(setindex!!(x, 2, 2)), zeros(3)) would not work with current BangBang since it just uses standard setindex!.

Thanks for having a look! I guess that was roughly my question: whether it could automatically hide the mutation when called by Zygote.

Re: Base.setindex, indeed this seems to work, just not on StaticArrays:

julia> Zygote.gradient((t,v) -> sum(Base.setindex(t, v, 2)), (0,0,0), 99)
((1, 0, 1), 1)

julia> using ArrayInterface

julia> Zygote.gradient((t,v) -> sum(Base.setindex(t, v, 2)), [0,0,0], 99)
([1, 0, 1], 1)