dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License
158 stars 12 forks source link

rrule piracy #95

Closed jw3126 closed 3 years ago

jw3126 commented 3 years ago
using Zygote, Yota
f(x) = sum(sin.(max.(x,0,x)))
Zygote.gradient(f, [1])

gives me

ERROR: LoadError: MethodError: no method matching length(::Nothing)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{LinearAlgebra.Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, LinearAlgebra.Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), LinearAlgebra.Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LinearAlgebra.LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), LinearAlgebra.Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), LinearAlgebra.Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, LinearAlgebra.UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), LinearAlgebra.UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), LinearAlgebra.UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/jan/.julia/packages/StaticArrays/0yhGP/src/abstractarray.jl:1
  length(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
  ...
Stacktrace:
  [1] unzip(tuples::Vector{Nothing})
    @ Yota ~/.julia/packages/Yota/W2ajF/src/chainrules.jl:38
  [2] rrule(::typeof(Base.Broadcast.broadcasted), ::typeof(max), ::Vector{Int64}, ::Int64, ::Vector{Int64})
    @ Yota ~/.julia/packages/Yota/W2ajF/src/chainrules.jl:44
  [3] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context}, ::Function, ::Function, ::Vector{Int64}, ::Int64, ::Vector{Int64})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/GciYT/src/rules.jl:134
  [4] chain_rrule
    @ ~/.julia/packages/Zygote/0da6K/src/compiler/chainrules.jl:103 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0 [inlined]
  [6] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::typeof(max), ::Vector{Int64}, ::Int64, ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:9
  [7] _pullback
    @ ~/projects/Yota/mwe.jl:3 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::typeof(Main.MWE.f), args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
  [9] _pullback(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:34
 [10] pullback(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:40
 [11] gradient(f::Function, args::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:58
 [12] top-level scope
    @ ~/projects/Yota/mwe.jl:4
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [14] top-level scope
    @ REPL[1]:1
in expression starting at /home/jan/projects/Yota/mwe.jl:1

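The problem seems to be caused by this type piracy: https://github.com/dfdx/Yota.jl/blob/02c77c7c666845fc032633effb24be085026410e/src/chainrules.jl#L43. Yota defines an `rrule` method for `Base.Broadcast.broadcasted`, a function Yota does not own. Zygote then dispatches to that method (frames [2] and [3] in the stack trace above) even though the user never calls Yota.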

dfdx commented 3 years ago

Ah, this rule was experimental and shouldn't have made it into the release; sorry for the inconvenience. #96 will fix it.