luiarthur / TuringBnpBenchmarks

Benchmarks of Bayesian Nonparametric models in Turing and other PPLs
https://luiarthur.github.io/TuringBnpBenchmarks/
MIT License

Error in gradient computation during StatsFuns benchmarks #8

Open · luiarthur opened this issue 4 years ago

luiarthur commented 4 years ago

I'm benchmarking some StatsFuns functions, namely logsumexp and normlogpdf, but I'm running into errors when computing the gradient of normlogpdf. Here's some code that reproduces the error.

Here's the environment.

```toml
# Project.toml (Julia v1.4.1)

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
```

This snippet shows that Flux can differentiate my own implementation of the log density, but somehow cannot compute the gradient of normlogpdf with respect to the location parameter.

```julia
using StatsFuns
using Flux

# Define variables
x, loc, scale = 0.0, 2.0, 1.0

# My implementation of the log density of Normal(loc, scale), evaluated at x
function my_normlogpdf(loc, scale, x)
    z = (x - loc) / scale
    return -z * z * 0.5 - 0.5 * log(2 * pi * scale * scale)
end

# evaluate
my_normlogpdf(loc, scale, x)  # -2.9189385332046727
# gradient (Flux.gradient returns a tuple with one entry per argument)
Flux.gradient(mu -> my_normlogpdf(mu, scale, x), loc)  # (-2.0,)

# evaluate
normlogpdf(loc, scale, x)  # -2.9189385332046727 (same as above)
# gradient
Flux.gradient(mu -> normlogpdf(mu, scale, x), loc)  # error?!
```
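For reference, the partial derivative of the Normal log density with respect to the location is (x - loc) / scale^2, which is -2.0 at x = 0, loc = 2, scale = 1. The sketch below, in plain Julia with hypothetical helper names (logpdf_normal, dlogpdf_dloc, dlogpdf_dscale, central_diff are not StatsFuns functions), checks the analytic derivatives against central finite differences, which gives an AD-independent target for the benchmark gradients.

```julia
# Log density of Normal(loc, scale) at x (same formula as my_normlogpdf).
logpdf_normal(loc, scale, x) = -0.5 * ((x - loc) / scale)^2 - log(scale) - 0.5 * log(2π)

# Analytic partial derivatives (hypothetical helper names).
dlogpdf_dloc(loc, scale, x) = (x - loc) / scale^2
dlogpdf_dscale(loc, scale, x) = ((x - loc)^2 / scale^2 - 1) / scale

# Central finite difference of a scalar function f at t.
central_diff(f, t; h = 1e-6) = (f(t + h) - f(t - h)) / (2h)

x, loc, scale = 0.0, 2.0, 1.0
fd_loc = central_diff(m -> logpdf_normal(m, scale, x), loc)
fd_scale = central_diff(s -> logpdf_normal(loc, s, x), scale)

# Analytic value next to its finite-difference estimate.
println(dlogpdf_dloc(loc, scale, x), " ", fd_loc)
println(dlogpdf_dscale(loc, scale, x), " ", fd_scale)
```

Any AD backend that works correctly here should agree with dlogpdf_dloc up to floating-point rounding.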

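This is the error being thrown. (The stack trace was captured from the benchmark function lpdf_gmm_sf, which broadcasts normlogpdf over arrays; the failure itself occurs in the pullback of normlogpdf, frames [1] to [9].)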

```
ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)
Closest candidates are:
Irrational{:log2π}(::T) where T<:Number at boot.jl:715
Irrational{:log2π}() where sym at irrationals.jl:18
Irrational{:log2π}(::Complex) where T<:Real at complex.jl:37
...
Stacktrace:
[1] convert(::Type{Irrational{:log2π}}, ::Int64) at ./number.jl:7
[2] one(::Type{Irrational{:log2π}}) at ./number.jl:276
[3] one(::Irrational{:log2π}) at ./number.jl:277
[4] (::Zygote.var"#603#604"{Float64,Irrational{:log2π}})(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/number.jl:29
[5] (::Zygote.var"#1590#back#605"{Zygote.var"#603#604"{Float64,Irrational{:log2π}}})(::Float64) at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[6] normlogpdf at /home/ubuntu/.julia/packages/StatsFuns/CXyCV/src/distrs/norm.jl:29 [inlined]
[7] (::typeof(∂(normlogpdf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[8] normlogpdf at /home/ubuntu/.julia/packages/StatsFuns/CXyCV/src/distrs/norm.jl:41 [inlined]
[9] (::typeof(∂(normlogpdf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[10] #1754 at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/broadcast.jl:142 [inlined]
[11] #3 at ./generator.jl:36 [inlined]
[12] iterate at ./generator.jl:47 [inlined]
[13] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(normlogpdf)),2},Array{Float64,2}}},Base.var"#3#4"{Zygote.var"#1754#1761"}}) at ./array.jl:665
[14] map at ./abstractarray.jl:2154 [inlined]
[15] (::Zygote.var"#1753#1760"{Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}},Val{4},Array{typeof(∂(normlogpdf)),2}})(::Array{Float64,2}) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/broadcast.jl:142
[16] #4425#back at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[17] (::Zygote.var"#174#175"{Zygote.var"#4425#back#1764"{Zygote.var"#1753#1760"{Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}},Val{4},Array{typeof(∂(normlogpdf)),2}}},Tuple{NTuple{4,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
[18] #347#back at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[19] broadcasted at ./broadcast.jl:1238 [inlined]
[20] lpdf_gmm_sf at /home/ubuntu/repo/TuringBnpBenchmarks/dev/Benchmark_BnpUtil/benchmark_methods.jl:34 [inlined]
[21] (::typeof(∂(lpdf_gmm_sf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[22] #46 at ./REPL[30]:2 [inlined]
[23] (::typeof(∂(#46)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[24] (::Zygote.var"#49#50"{Zygote.Params,Zygote.Context,typeof(∂(#46))})(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
[25] gradient(::Function, ::Zygote.Params) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
[26] top-level scope at REPL[30]:1
```
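The top frames show where it fails: a Zygote pullback for an arithmetic operation whose arguments are a Float64 and the constant log2π (an Irrational{:log2π}) calls one on the constant, and on Julia v1.4.1 that falls through to convert(Irrational{:log2π}, 1), which has no method. A minimal workaround sketch, assuming a hand-written density is acceptable wherever gradients are needed, keeps the same log density but stores log(2π) as a Float64, so no Irrational ever reaches the pullback. LOG2PI_F64 and normlogpdf_f64 are hypothetical names, not StatsFuns API.

```julia
# log(2π) materialized as a Float64 (2π is already a Float64 in Julia),
# instead of an Irrational constant like StatsFuns' log2π.
const LOG2PI_F64 = log(2π)

# Normal log density at x: standardize, then -(z^2 + log(2π)) / 2 - log(σ).
function normlogpdf_f64(μ, σ, x)
    z = (x - μ) / σ
    return -(abs2(z) + LOG2PI_F64) / 2 - log(σ)
end

println(normlogpdf_f64(2.0, 1.0, 0.0))

# With Flux (or Zygote) loaded, the failing call from the issue would become:
# Flux.gradient(mu -> normlogpdf_f64(mu, 1.0, 0.0), 2.0)  # expected (-2.0,)
```

The design choice is simply to keep every constant in the differentiated code a concrete floating-point number, since one(::Float64) is always defined.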

I'm a little confused, because no error is thrown when I use normlogpdf in a Turing model with an AD-based inference algorithm.
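One possible explanation, which is an assumption and not something established in this thread, is that the Turing model was differentiated in forward mode with ForwardDiff (Turing's default AD backend at the time) rather than in reverse mode with Zygote. Forward mode carries a derivative alongside each value, and adding a constant only shifts the value, so nothing asks for one() of the Irrational constant. The toy dual-number type below is a hypothetical illustration in plain Julia, not ForwardDiff's implementation; LOG2PI_IRR is a hypothetical Irrational constant defined with Base.@irrational to stand in for log2π.

```julia
# Toy forward-mode dual number: a value plus its derivative with respect to
# one seeded input. Deliberately not a subtype of Number, to keep dispatch simple.
struct Dual
    val::Float64
    der::Float64
end

# Only the operations the Normal log density needs.
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)  # a constant shifts the value only
Base.:-(a::Dual, b::Real) = Dual(a.val - b, a.der)
Base.:-(a::Real, b::Dual) = Dual(a - b.val, -b.der)
Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.:/(a::Dual, b::Real) = Dual(a.val / b, a.der / b)
Base.abs2(a::Dual) = Dual(a.val^2, 2 * a.val * a.der)

# A hypothetical Irrational constant equal to log(2π).
Base.@irrational LOG2PI_IRR 1.8378770664093454835606594728112352797227949472755668 log(2 * big(π))

# Normal log density with the Irrational constant left in place.
toy_normlogpdf(μ, σ, x) = -(abs2((x - μ) / σ) + LOG2PI_IRR) / 2 - log(σ)

# Seed the location with derivative 1.0 to get d/dloc in a single pass.
result = toy_normlogpdf(Dual(2.0, 1.0), 1.0, 0.0)
println(result.val, " ", result.der)
```

Here Dual + Irrational just adds the constant to the value through ordinary Float64 promotion, and the derivative comes out as (x - loc) / scale^2 = -2.0 at these inputs.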

trappmartin commented 4 years ago

I'll have a look at it.