probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Bug in `logpdf_grad` implementation of `@dist` DSL distributions. #496

Closed: ztangent closed this 1 year ago

ztangent commented 1 year ago

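This happens with both `WithLabelArg` and `RelabeledDistribution` distributions, i.e. `@dist` bodies that index into a collection of labels, whether the labels are passed as an argument or fixed in the body.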

Minimal example:

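using Gen

# Labels passed as an argument (compiled to a WithLabelArg distribution)
@dist labeled_uniform(labels) = labels[uniform_discrete(1, length(labels))]
logpdf_grad(labeled_uniform, :a, [:a, :b, :c])  # throws BoundsError

# Labels fixed in the body (compiled to a RelabeledDistribution)
@dist relabeled_uniform() = [:c, :d, :f][uniform_discrete(1, 3)]
logpdf_grad(relabeled_uniform, :c)  # throws BoundsError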

Associated errors (the first is raised by `relabeled_uniform`, the second by `labeled_uniform`):

ERROR: BoundsError: attempt to access 2-element Vector{Nothing} at index [3]
Stacktrace:
 [1] getindex
   @ .\array.jl:861 [inlined]
 [2] logpdf_grad(::Gen.RelabeledDistribution{Symbol, Int64}, ::Symbol, ::Int64, ::Int64)
   @ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\relabeled_distribution.jl:83
 [3] logpdf_grad(::Gen.CompiledDistWithArgs{Symbol}, ::Symbol)
   @ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:73
 [4] top-level scope

ERROR: BoundsError: attempt to access 2-element Vector{Nothing} at index [3]
Stacktrace:
 [1] getindex
   @ .\array.jl:861 [inlined]
 [2] logpdf_grad(::Gen.WithLabelArg{Any, Int64}, ::Symbol, ::Vector{Symbol}, ::Int64, ::Int64)
   @ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\relabeled_distribution.jl:29
 [3] logpdf_grad(d::Gen.CompiledDistWithArgs{Any}, x::Symbol, args::Vector{Symbol})
   @ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:73
 [4] top-level scope

The fix should be straightforward, since this is just an indexing error. I can get to it soon.
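To make the index bookkeeping concrete, here is a hypothetical sketch; it is not the code in `relabeled_distribution.jl` nor the fix that landed in #497. It assumes Gen's convention that `logpdf_grad(dist, x, args...)` returns a tuple whose first entry is the gradient with respect to the value `x` and whose remaining entries correspond to the arguments, with `nothing` where no gradient is available. The `BoundsError` above (a 2-element vector indexed at 3) is consistent with an offset mismatch between the wrapper's argument positions and the base distribution's argument gradients. The helper names `relabeled_grads` and `with_label_grads` are illustrative only.

```julia
# Hypothetical sketch, not Gen's implementation: re-index the gradient tuple of a
# base distribution (e.g. uniform_discrete(low, high)) for label-valued wrappers.
# Assumed convention: grads = (d/dx, d/darg1, ..., d/dargn), `nothing` if unsupported.

# Wrapper with fixed labels: its arguments are exactly the base arguments.
function relabeled_grads(base_grads::Tuple)
    base_arg_grads = base_grads[2:end]      # drop the gradient w.r.t. the base value
    # A label value is not differentiable, so its gradient is `nothing`.
    return (nothing, base_arg_grads...)
end

# Wrapper with labels passed as the first argument: (labels, base_args...).
function with_label_grads(base_grads::Tuple)
    base_arg_grads = base_grads[2:end]
    # Value and labels get `nothing`; base argument gradients keep their order.
    return (nothing, nothing, base_arg_grads...)
end

# uniform_discrete(low, high) supports no gradients: one `nothing` per slot.
base = (nothing, nothing, nothing)
println(length(relabeled_grads(base)))   # 3 entries: value, low, high
println(length(with_label_grads(base)))  # 4 entries: value, labels, low, high
```

The point of the design is that each wrapper's output has exactly one entry for the value plus one per wrapper argument, so no index ever runs past the base distribution's argument gradients.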

`logpdf_grad` appears to work for TransformedDistributions, except when the `@dist` takes no arguments:

# With arguments: works
@dist shifted_normal(mu, sigma) = Gen.normal(mu, sigma) + 1.0
logpdf_grad(shifted_normal, 0.0, 0.0, 1.0)

# Without arguments: throws a MethodError
@dist shifted_std_normal() = Gen.normal(0.0, 1.0) + 1.0
logpdf_grad(shifted_std_normal, 0.0)

Calling `logpdf_grad` on `shifted_std_normal` raises the following error and backtrace:

ERROR: MethodError: zero(::Type{Union{}}) is ambiguous. Candidates:
...
Stacktrace:
  [1] track(x::Vector{Union{}}, ::Type{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction})
    @ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\tracked.jl:473
  [2] ReverseDiff.GradientConfig(input::Vector{Union{}}, ::Type{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction})
    @ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:50
  [3] ReverseDiff.GradientConfig(input::Vector{Union{}}, tp::Vector{ReverseDiff.AbstractInstruction}) (repeats 2 times)
    @ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:35
  [4] gradient(f::Function, input::Vector{Union{}})
    @ ReverseDiff ...\.julia\packages\ReverseDiff\YkVxM\src\api\gradients.jl:22
  [5] (::Gen.var"#46#52"{Vector{Union{}}})(::Tuple{Int64, Float64})
    @ Gen .\none:0
  [6] iterate
    @ .\generator.jl:47 [inlined]
  [7] grow_to!
    @ .\array.jl:797 [inlined]
  [8] collect(itr::Base.Generator{Base.Iterators.Filter{Gen.var"#48#54"{Vector{Bool}}, Base.Iterators.Enumerate{Tuple{Float64, Float64}}}, Gen.var"#46#52"{Vector{Union{}}}})
    @ Base .\array.jl:721
  [9] logpdf_grad(::Gen.CompiledDistWithArgs{Float64}, ::Float64)
    @ Gen ...\.julia\packages\Gen\Dne3u\src\modeling_library\dist_dsl\dist_dsl.jl:78
 [10] top-level scope

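It's less obvious to me how to fix this one, but I presume we can just write a special case for when there are zero arguments. The backtrace shows ReverseDiff receiving an empty `Vector{Union{}}` as its input, and `zero(Union{})` is ambiguous.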
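A minimal sketch of that zero-argument special case follows. It assumes (the backtrace does not show this directly) that the input vector handed to `ReverseDiff.gradient` is built by collecting the `@dist`'s arguments. Collecting an empty tuple yields a `Vector{Union{}}`, the same input type seen in the backtrace, so one option is to return early when there is nothing to differentiate. The helper `args_input` is hypothetical and does not call ReverseDiff; under Gen's gradient convention, a zero-argument distribution would then only need the entry for the value.

```julia
# Collecting an empty tuple of arguments gives element type Union{},
# matching the `input::Vector{Union{}}` in the ReverseDiff backtrace.
empty_input = collect(())
println(eltype(empty_input))   # Union{}

# Hypothetical helper: build a concretely typed input vector for differentiation,
# or return `nothing` to signal there is nothing to differentiate.
function args_input(args::Tuple)
    isempty(args) && return nothing    # caller should skip the gradient call entirely
    return collect(Float64, args)      # concrete eltype, so zero(Float64) is well defined
end

println(args_input(()) === nothing)    # true: take the special-case path
println(args_input((0.0, 1.0)))        # [0.0, 1.0]
```

Returning early (rather than passing a typed empty `Float64[]` to the AD backend) keeps the special case independent of how ReverseDiff handles empty inputs.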

ztangent commented 1 year ago

Resolved by #497.