dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License
158 stars 12 forks source link

MethodError: no method matching length(::Type{Val{2}}) when differentiating log-likelihood #121

Closed ForceBru closed 2 years ago

ForceBru commented 2 years ago

Code

import Random
import Yota

normal_pdf(x, mean, var) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

rng = Random.MersenneTwister(42)
data = randn(rng, 100)

Yota.grad(
    mu -> sum(log, normal_pdf.(data, mu, 1.0)),
    1.0
)

Error message

julia> include("yota_err.jl")
ERROR: LoadError: MethodError: no method matching length(::Type{Val{2}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{ArrayInterfaceCore.BidiagonalIndex, ArrayInterfaceCore.TridiagonalIndex}) at ~/.julia/packages/ArrayInterfaceCore/7kMjZ/src/ArrayInterfaceCore.jl:594
  length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.Diagonal{T, <:StaticArraysCore.StaticArray{Tuple{var"#s13"}, T, 1} where var"#s13"}, LinearAlgebra.Hermitian{T, <:StaticArraysCore.StaticArray{Tuple{var"#s10", var"#s11"}, T, 2} where {var"#s10", var"#s11"}}, LinearAlgebra.LowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s18", var"#s19"}, T, 2} where {var"#s18", var"#s19"}}, LinearAlgebra.Symmetric{T, <:StaticArraysCore.StaticArray{Tuple{var"#s7", var"#s8"}, T, 2} where {var"#s7", var"#s8"}}, LinearAlgebra.Transpose{T, <:Union{StaticArraysCore.StaticArray{Tuple{var"#s2"}, T, 1} where var"#s2", StaticArraysCore.StaticArray{Tuple{var"#s3", var"#s4"}, T, 2} where {var"#s3", var"#s4"}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s24", var"#s25"}, T, 2} where {var"#s24", var"#s25"}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s21", var"#s22"}, T, 2} where {var"#s21", var"#s22"}}, LinearAlgebra.UpperTriangular{T, <:StaticArraysCore.StaticArray{Tuple{var"#s15", var"#s16"}, T, 2} where {var"#s15", var"#s16"}}, StaticArraysCore.StaticArray{Tuple{var"#s25"}, T, 1} where var"#s25", StaticArraysCore.StaticArray{Tuple{var"#s1", var"#s3"}, T, 2} where {var"#s1", var"#s3"}, StaticArraysCore.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/8Dz3j/src/abstractarray.jl:1
  ...
Stacktrace:
  [1] unzip(tuples::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:92
  [2] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:49
  [3] bcast_rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(Core.apply_type), ::Type, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:48
  [4] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
  [5] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
  [6] record_or_recurse!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:85
  [7] trace!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
  [8] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, fargtypes::Tuple{typeof(normal_pdf), Tuple{DataType, DataType, DataType}}, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
  [9] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:136
 [10] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/cr_api.jl:170
 [11] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::Float64, ::Float64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/rulesets.jl:98
 [12] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:194
 [13] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/tape.jl:179
 [14] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:49
 [15] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:193
 [16] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Umlaut.Variable, ::Vararg{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:220
 [17] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/0uyNC/src/trace.jl:346
 [18] #gradtape#90
    @ ~/.julia/packages/Yota/VCIzN/src/grad.jl:243 [inlined]
 [19] grad(f::var"#101#102", args::Float64; seed::Int64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:314
 [20] grad(f::var"#101#102", args::Float64)
    @ Yota ~/.julia/packages/Yota/VCIzN/src/grad.jl:306
 [21] top-level scope
    @ ~/test/autodiff_bench/yota_err.jl:12
 [22] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [23] top-level scope
    @ REPL[36]:1
in expression starting at /Users/forcebru/test/autodiff_bench/yota_err.jl:12

Versions

mcabbott commented 2 years ago

FWIW the error is slightly different on master:

```julia
julia> import Yota

julia> normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var);

julia> Yota.grad((x, mu) -> sum(log, normal_pdf.(x, mu, 1.0)), rand(10), 1.0)
┌ Error: Failed to compile rrule for broadcasted(normal_pdf, [0.8031553730805592, 0.3509560552123825, 0.032551822966513155, 0.21170603638555385, 0.2049628078398853, 0.8469464153815023, 0.4220217037413583, 0.286621858419426, 0.0338940405059448, 0.43421951685956195], 1.0, 1.0), extract details via:
│     (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:179
ERROR: MethodError: no method matching iterate(::Type{Val})
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:869
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:869
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...
Stacktrace:
  [1] first(itr::Type)
    @ Base ./abstractarray.jl:436
  [2] map(f::typeof(first), t::Tuple{UnionAll, Int64})
    @ Base ./tuple.jl:274
  [3] trace_call!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:92
  [4] trace_block!(t::Umlaut.Tracer{Yota.BcastGradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:312
  [5] trace!(t::Umlaut.Tracer{Yota.BcastGradCtx}, v_fargs::Vector{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:436
  [6] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:546
  [7] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:136
  [8] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:172
  [9] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::Float64, ::Float64)
    @ Yota ~/.julia/packages/Yota/QGPcM/src/rulesets.jl:91
 [10] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/tape.jl:202
 [11] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:54
 [12] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:286
 [13] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:312
 [14] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:436
 [15] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:546
 [16] gradtape(::var"#3#4", ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
    @ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:258
 [17] grad(::var"#3#4", ::Vector{Float64}, ::Vararg{Any}; seed::Int64)
    @ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:334
 [18] grad(::var"#3#4", ::Vector{Float64}, ::Vararg{Any})
    @ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:326
 [19] top-level scope
    @ REPL[5]:1

(jl_atLVz9) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_atLVz9/Project.toml`
  [92992a2b] Umlaut v0.4.2 `https://github.com/dfdx/Umlaut.jl.git#main`
  [cd998857] Yota v0.7.5 `https://github.com/dfdx/Yota.jl.git#main`
```

If I disable all rules related to broadcasting in Yota, then the error is more straightforward, and I think it tells us that the rules from https://github.com/JuliaDiff/ChainRules.jl/pull/644 (which ought to cover this kind of broadcasting) are not being called:

julia> using Yota

julia> normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var);

julia> Yota.grad((x, mu) -> sum(log, normal_pdf.(x, mu, 1.0)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using 

    ChainRulesCore.rrule(::typeof(Base.Broadcast.materialize), ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(normal_pdf), Tuple{Vector{Float64}, Float64, Float64}}) = ...

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
   @ Yota ~/.julia/dev/Yota/src/grad.jl:185
...

# same error on a much easier broadcast, should use derivatives_given_output

julia> Yota.grad((x, mu) -> sum(log, atan.(x, mu)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using ...

# and an even easier one, has its own rrule(broadcasted, +, ...)

julia> Yota.grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
ERROR: No deriative rule found for op %7 = materialize(%5)::Vector{Float64} , try defining it using ...

(jl_fINTq0) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_fINTq0/Project.toml`
  [082447d4] ChainRules v1.43.2
  [92992a2b] Umlaut v0.4.2 `https://github.com/dfdx/Umlaut.jl.git#main`
  [cd998857] Yota v0.7.5 `~/.julia/dev/Yota`
dfdx commented 2 years ago

Oh, I think I messed up broadcasting on main recently :( Right now I'm fixing a possibly related bug, so let's see how it works after that (ETA ~1-2 days).

dfdx commented 2 years ago

@mcabbott If I understand it correctly, JuliaDiff/ChainRules.jl#644 doesn't provide an rrule(materialize, ...), so if you commented it out from Yota too (as I just did), it makes sense that Yota started to complain about the missing rule. At least, when I uncommented the rule, this example started to work (on fix-kw-rrule branch):

julia> grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
(2.515783933810798, (ZeroTangent(), [0.5835340506097648, 0.6717831140646412, 0.9651554343975317, 0.5720477538411006, 0.9756415208897008, 0.7392737693342455, 0.9473341524270764, 0.9914851197981195, 0.9458598901263214, 0.5826024579088742], 7.974717263397377))

rrule(broadcasted, normal_pdf, ...) still doesn't hit the new generic broadcasting though. Is it correct that I need to invoke the signature with ::BroadcastStyle to make this work?

mcabbott commented 2 years ago

Oh right, sorry, the rule for materialize is indeed still needed. And yes, I think Yota somehow needs to fall through to the signature with ::BroadcastStyle (and more Refs).

dfdx commented 2 years ago

The fix is now on main, featuring the new generic broadcasting from https://github.com/JuliaDiff/ChainRules.jl/pull/644

mcabbott commented 2 years ago

Thanks! Are you thinking of making a release sometime soon?

dfdx commented 2 years ago

@mcabbott Yep, tagged v0.7.4.