TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Domain Error for VecCholeskyBijector bijector when calling logabsdetjac #279

Open paschermayr opened 1 year ago

paschermayr commented 1 year ago

Hi there,

Thanks for all the hard work and for updating all the transforms!

I noticed that I get a DomainError when evaluating `logabsdetjac` after mapping a parameter from the unconstrained space back to the constrained space via the inverse bijector of the LKJ Cholesky transform (MWE below).

Would it be possible to return NaN or -Inf instead of throwing that error? I use this bijector in a larger code base, and the error sometimes stops a loop. I assume the same would happen during Turing MCMC iterations?

using Bijectors     #v0.10.35
using Distributions #v0.25.98
using ForwardDiff   #v0.10.35
using ReverseDiff   #v1.15.0

θ_unconstrained = [
    -1.9887091960524537,
    -13.499454444466279,
    -0.39328331954134665,
    -4.426097270849902,
    13.101175413857023,
    7.66647404712346,
    9.249285786544894,
    4.714877413573335,
    6.233118490809442,
    22.28264809311481
]
n = 5
nparam = binomial(n, 2)
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ) # ERROR: DomainError with -1.2425014227268605e-5:

yebai commented 1 year ago

@paschermayr Thanks for reporting this. Maybe create a PR with some tests?

torfjelde commented 1 year ago

Would it be possible to return a NaN or -Inf instead of throwing that error?

I think this is a bit too drastic to make the default behavior, but I wouldn't be opposed to supporting some sort of "mode" for this, though I'm uncertain exactly where to put it.

I use this bijector in a larger code base and this sometimes stops a loop, I assume this would be similar during Turing MCMC iterations?

If you use the bijector directly, as in the above snippet, then it should be fairly easy to just check if you hit a DomainError yourself, and return -Inf in that scenario, no? E.g.

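# Assign the value of the whole try/catch expression so that `res`
# is visible afterwards (a variable first assigned inside `try` is local to it).
res = try
    Bijectors.logabsdetjac(b, θ)
catch e
    # Only swallow DomainError; rethrow anything else unchanged.
    e isa DomainError || rethrow()
    -Inf
end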

If you're running into this in a Turing model, then we probably need to implement an evaluation mode or something similar, as I mentioned above, to deal with this.

paschermayr commented 1 year ago

If you use the bijector directly, as in the above snippet, then it should be fairly easy to just check if you hit a DomainError yourself, and return -Inf in that scenario, no? E.g.

Right, this is my current workaround, but a try-catch method would make other libraries (like Zygote) incompatible.

If you're running into this in a Turing model, then we probably need to implement an evaluation mode or something similar, as I mentioned above, to deal with this.

I will have to try it myself, but consider a Turing model that estimates the covariance matrix of a multivariate normal (or some more exotic alternative) via LKJCholesky plus the diagonal terms, i.e. Σ = Symmetric(diagm(σ) * ρ.factors * ρ.factors' * diagm(σ)). Would you not see that error too when using NUTS? I assume Bijectors.logabsdetjac is used there too, to account for the transformation from unconstrained to constrained space.
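
As a rough sketch of that construction (plain LinearAlgebra, no Turing; the values of σ and R are made up for illustration, and L stands in for the lower Cholesky factor of a correlation matrix):

using LinearAlgebra

# Hypothetical values, for illustration only.
σ = [0.5, 2.0, 1.5]            # standard deviations
R = [1.0 0.3 0.1;
     0.3 1.0 0.2;
     0.1 0.2 1.0]              # a valid correlation matrix
L = cholesky(Symmetric(R)).L   # lower factor, so that L * L' ≈ R

# Σ = diag(σ) * L * L' * diag(σ); Symmetric removes rounding asymmetry.
Σ = Symmetric(Diagonal(σ) * L * L' * Diagonal(σ))

The diagonal of Σ then holds the variances σ.^2, and the off-diagonal entries are the correlations in R scaled by the corresponding standard deviations.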

@paschermayr Thanks for reporting this. Maybe create a PR with some tests?

Thank you for reaching out! I will first have to find out what exactly is causing the DomainError.

torfjelde commented 1 year ago

Right, this is my current workaround, but a try-catch method would make other libraries (like Zygote) incompatible.

Ah true :confused:

I assume Bijectors.logabsdetjac is used there too to account for the transformations from unconstrained to constrained space.

Yeah. So when I said "implement an evaluation-mode ... that deals with this", I meant a "mode" for Turing that would evaluate models in this way. Unfortunately, this wouldn't avoid the compatibility issue with something like Zygote :confused:

Likely the best way of dealing with this is just to improve the numerical stability of the transform. Will have a proper look at the example you've posted once I find the time!
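
For illustration only, here is a minimal sketch of the kind of failure involved. It assumes (this is not confirmed in the thread) that the error comes from taking the log of a quantity that rounding has pushed slightly below zero. Julia's `log` throws a DomainError for negative real input, and a plain branch can avoid the exception without try/catch. `guarded_log` is a hypothetical helper, not part of Bijectors.jl.

# Hypothetical helper, not part of Bijectors.jl.
# Returns -Inf instead of throwing when x is not strictly positive.
guarded_log(x::Real) = x > zero(x) ? log(x) : oftype(float(x), -Inf)

x = -1.2425014227268605e-5   # the value reported in the DomainError above

# Plain `log` throws for this input; record whether it was a DomainError.
threw = try
    log(x)
    false
catch e
    e isa DomainError
end

y = guarded_log(x)           # no exception is raised here

A guard like this only masks the symptom. If the negative value is a rounding artifact, reworking the computation so the argument stays non-negative would be the more robust fix.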