TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Support for complex numbers #262

Closed: itsdfish closed this issue 3 years ago

itsdfish commented 3 years ago

Hi @xukai92,

As discussed on Slack, adding support for complex numbers would allow us to use many models from physics. The code below is a quantum model of human judgment based on Wang et al. (2014). The model has a single parameter theta, which rotates the basis vectors. Thank you for looking into this. This feature would be very useful for me and others as well.

function quantum_model(S, θ)
    N = length(S)
    H = zeros(N, N) # Hamiltonian
    idx = CartesianIndex.(2:N,1:(N-1))
    dind = diagind(H)
    H[idx] .= 1.0
    H .+= H'
    H[dind] .= 1:N
    U = exp(-1im * θ * H) 
    na = 1
    PA = zeros(N, N)
    PA[dind] .= [ones(na); zeros(N - na)]
    PnA = I - PA
    Sa = PA * S
    Sa ./= sqrt(Sa' * Sa) #normalized projection A
    Sna = PnA * S; 
    Sna ./= sqrt(Sna' * Sna) # normalized projection B

    nb = 1 # dim of B subspace
    PB = zeros(ComplexF64, N, N)
    PB[dind] .= [zeros(nb); ones(N - nb)] # projector for B in B coord
    PB .= U * PB * U' # projector for B in A coord
    PnB = I - PB

    Sb = PB * S
    Sb ./= sqrt(Sb' * Sb)
    Snb = PnB * S
    Snb ./= sqrt(Snb' * Snb)

    pAtB = (S' * PA * S) * (Sa' * PB * Sa) # prob A then B
    pAtnB = (S' * PA * S) * (Sa' * PnB * Sa)
    pnAtB = (S' * PnA * S)*(Sna' * PB * Sna)
    pnAtnB = (S' * PnA * S) * (Sna' * PnB * Sna)
    # pAtB + pAtnB + pnAtB + pnAtnB == 1
    pBtA = (S' * PB * S)*(Sb' * PA * Sb) # prob B then A
    pBtnA = (S' * PB * S) * (Sb' * PnA * Sb)
    pnBtA = (S' * PnB*S) * (Snb' * PA * Snb)
    pnBtnA = (S' * PnB * S) * (Snb' * PnA * Snb)
    # pBtA + pBtnA + pnBtA + pnBtnA == 1
    # order 1
    c_probs1 = [pAtB, pAtnB , pnAtB, pnAtnB]
    # order 2
    c_probs2 = [pBtA, pnBtA, pBtnA, pnBtnA]
    return map(real, c_probs1), map(real, c_probs2)
end

function simulate(S, θ, n_sim)
    p1,p2 = quantum_model(S, θ)
    y1 = rand(Multinomial(n_sim, p1))
    y2 = rand(Multinomial(n_sim, p2))
    return y1, y2
end

using Distributions, Turing, LinearAlgebra
import Distributions: logpdf, loglikelihood

"""
Simplified model based on 
    Wang, Z., Solloway, T., Shiffrin, R. M., & Busemeyer, J. R. (2014). 
    Context effects produced by question orders reveal quantum nature of human 
    judgments. Proceedings of the National Academy of Sciences, 111(26), 9431-9436.

"""
struct Quantum{T1,T2} <: ContinuousUnivariateDistribution
    θ::T1
    S::T2
    n::Int64
end

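function logpdf(d::Quantum, data)
    # p = (p1, p2): predicted cell probabilities for the two question orders
    p = quantum_model(d.S, d.θ)
    # broadcasting over the tuples pairs p1 with data[1] and p2 with data[2]
    LL = @. logpdf(Multinomial(d.n, p), data)
    return sum(LL)
end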

loglikelihood(d::Quantum, data::Tuple{Vector{Int64}, Vector{Int64}}) = logpdf(d, data)

# number of observations per condition
n_sim = 100
# dimensionality of Hilbert Space
N = 4
# state vector
S = fill(sqrt(.25), N)
# rotation
θ = 2.0
data = simulate(S, θ, n_sim)

@model model(data, S, n_sim) = begin
    θ ~ Truncated(Normal(2, 2), 0.0, Inf)
    data ~ Quantum(θ, S, n_sim)
end

# Settings of the NUTS sampler.
n_samples = 1000
delta = 0.85
n_adapt = 1000
n_chains = 4
specs = NUTS(n_adapt, delta)
# Start sampling.
chain = sample(model(data, S, n_sim), specs, MCMCThreads(), n_samples, n_chains, progress=true)
xukai92 commented 3 years ago

Thanks for the example. It would not be difficult to support it programmatically, but what I don't quite understand is how HMC would work on the complex domain, which I need to look into a bit.

On the practical side, does this (or any other) inference problem with complex numbers have a closed-form solution, so that we can also check whether the inference is correct?

itsdfish commented 3 years ago

Unfortunately, I am not very familiar with these models yet, so I don't know whether any of them have a closed-form solution. I will ask my colleague next week to see if he has any ideas.

I wonder if it would be possible to modify a simple Binomial model so that it is in a complex domain.

sethaxen commented 3 years ago

The reason the above model does not work has nothing to do with AdvancedHMC.jl. AdvancedHMC does indeed require that all parameters be real, but your only parameter, θ, is real, so this is fine. It does not matter if intermediate quantities are complex (I likewise use models with real parameters, complex intermediates, and real probabilities).
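To illustrate real parameters with complex intermediates, here is a toy sketch using a hypothetical two-level system (`H2` and `transition_prob` are illustrative names, not part of the Wang et al. model): a real θ goes through the complex matrix exponential exp(-iθH) and comes out as a real transition probability. For this particular H the probability works out to sin(θ)², so a central finite difference of the real output can be compared with the closed-form derivative sin(2θ). Nothing complex is ever sampled.

```julia
using LinearAlgebra

# Hypothetical two-level Hamiltonian (illustration only, not the Wang et al. model)
const H2 = [0.0 1.0; 1.0 0.0]

# Real parameter θ -> complex unitary exp(-iθ H2) -> real probability.
# Because H2^2 = I, exp(-iθ H2) = cos(θ) I - i sin(θ) H2, so the result is sin(θ)^2.
function transition_prob(θ::Real)
    U = exp(-1im * θ * H2)   # complex intermediate (Matrix{ComplexF64})
    return abs2(U[2, 1])     # |<e2|U|e1>|^2 is a real number
end

θ0 = 0.7
h = 1e-5
# Central finite difference of the real-valued output with respect to the real θ
fd = (transition_prob(θ0 + h) - transition_prob(θ0 - h)) / (2h)
println((transition_prob(θ0), sin(θ0)^2))
println((fd, sin(2θ0)))      # d/dθ sin(θ)^2 = sin(2θ)
```

The gradient target stays a function from reals to reals, which is all HMC needs; the complex numbers only matter to the AD backend that differentiates through them.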

The issue here is twofold. First, ForwardDiff has only partial support for complex numbers, so while it might work for some models, expect it to fail for others. In this case, the problem is that your model uses the matrix exponential, whose LinearAlgebra methods are restricted to StridedMatrix arguments with BlasFloat (or integer) element types, so there is no method for a matrix of complex ForwardDiff.Dual numbers. See the error:

        nested task error: MethodError: no method matching exp(::Matrix{Complex{ForwardDiff.Dual{ForwardDiff.Tag{Turing.Core.var"#f#1"{DynamicPPL.TypedVarInfo{NamedTuple{(:θ,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:θ, Tuple{}}, Int64}, Vector{Truncated{Normal{Float64}, Continuous, Float64}}, Vector{AbstractPPL.VarName{:θ, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Model{var"#1#2", (:data, :S, :n_sim), (), (), Tuple{Tuple{Vector{Int64}, Vector{Int64}}, Vector{Float64}, Int64}, Tuple{}}, DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext}, Float64}, Float64, 1}}})
        Closest candidates are:
          exp(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/dense.jl:557
          exp(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Integer, Complex{var"#s831"} where var"#s831"<:Integer}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/dense.jl:558
          exp(::StaticArrays.StaticMatrix{N, M, T} where {N, M, T}) at /Users/saxen/.julia/packages/StaticArrays/rdb0l/src/expm.jl:1

So any AD that uses type overloading will fail here unless it defines a special rule for the matrix exp. The solution to both of these problems is to use Zygote, which has the best out-of-the-box AD support for complex numbers and does not use type overloading. So you would add

using Zygote
Turing.setadbackend(:zygote)

to your example. One thing to keep in mind, though, is that Zygote does not support code that mutates arrays, so you would need to either rewrite quantum_model to be non-mutating or write a ChainRules/ZygoteRules rule for it. If you run into problems, I'm happy to advise on the #autodiff channel on Slack.
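A minimal sketch of what a non-mutating version of quantum_model might look like, assuming the same inputs (state vector S, angle θ) and the same output ordering as the original. The name `quantum_model_nm` and the keyword arguments `na`/`nb` are hypothetical; the Hamiltonian and projectors are built with `SymTridiagonal` and `Diagonal` instead of writing into `zeros(N, N)`. Whether Zygote then differentiates the whole function end to end, including the matrix exp, is not verified here.

```julia
using LinearAlgebra

# Hypothetical non-mutating rewrite of quantum_model: same inputs and outputs,
# but no in-place writes (.=, ./=, .+=, or indexed assignment).
function quantum_model_nm(S, θ; na = 1, nb = 1)
    N = length(S)
    # Hamiltonian: 1:N on the diagonal, ones on the first off-diagonals
    H = Matrix(SymTridiagonal(collect(1.0:N), ones(N - 1)))
    U = exp(-1im * θ * H)
    # Projectors built directly from their diagonals
    PA = Diagonal([ones(na); zeros(N - na)])
    PnA = I - PA
    PB = U * Diagonal([zeros(nb); ones(N - nb)]) * U'   # projector for B in A coord
    PnB = I - PB
    # Normalized projection of the state vector S onto the range of P
    proj(P) = (v = P * S; v / sqrt(v' * v))
    Sa, Sna, Sb, Snb = proj(PA), proj(PnA), proj(PB), proj(PnB)
    pA, pnA = S' * PA * S, S' * PnA * S
    pB, pnB = S' * PB * S, S' * PnB * S
    # order 1: A then B
    c_probs1 = [pA * (Sa' * PB * Sa), pA * (Sa' * PnB * Sa),
                pnA * (Sna' * PB * Sna), pnA * (Sna' * PnB * Sna)]
    # order 2: B then A, in the same cell order as the original
    c_probs2 = [pB * (Sb' * PA * Sb), pnB * (Snb' * PA * Snb),
                pB * (Sb' * PnA * Sb), pnB * (Snb' * PnA * Snb)]
    return real.(c_probs1), real.(c_probs2)
end

c1, c2 = quantum_model_nm(fill(sqrt(0.25), 4), 2.0)
```

Each of `c1` and `c2` should sum to 1, because PB + PnB = I, PA + PnA = I, and S is a unit vector.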

On the initial issue of supporting HMC with complex parameters, I don't think this is a good fit for this package. Most models will have a mixture of real and complex parameters, which could not fit well in an array of uniform eltype. Then there's the question of how to handle the metric, whose dimension would not match the number of degrees of freedom. No distributions in Distributions.jl support complex random variables, and the complex distributions I know of can all be trivially constructed in terms of a real distribution. It's straightforward to define complex parameters in terms of real ones, so unless we have a good example where this just won't work, I recommend this package keep to sampling real-valued parameters.
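As an example of a complex distribution built from a real one: the circularly symmetric complex normal CN(μ, σ²) is just a pair of independent real normals, each with variance σ²/2, on the real and imaginary parts. The helper names below are hypothetical (Distributions.jl does not provide them); the sketch checks that the complex form and the two-real-parameter form give the same log density, so a model can sample `x` and `y` and form `complex(x, y)` itself.

```julia
# Hypothetical helper: log density of a circularly symmetric complex normal CN(μ, σ²),
#   p(z) = exp(-|z - μ|² / σ²) / (π σ²)
logpdf_cnormal(z::Complex, μ::Complex, σ²::Real) = -log(π * σ²) - abs2(z - μ) / σ²

# Hypothetical helper: real normal log density with mean m and variance v
logpdf_rnormal(x::Real, m::Real, v::Real) = -0.5 * log(2π * v) - (x - m)^2 / (2v)

# Same density written purely in terms of two real parameters x, y (z = complex(x, y)):
# real and imaginary parts are independent normals with variance σ²/2 each
logpdf_cnormal_via_real(x::Real, y::Real, μ::Complex, σ²::Real) =
    logpdf_rnormal(x, real(μ), σ² / 2) + logpdf_rnormal(y, imag(μ), σ² / 2)

z, μ, σ² = complex(0.3, -1.2), complex(0.5, 0.1), 2.0
println((logpdf_cnormal(z, μ, σ²), logpdf_cnormal_via_real(real(z), imag(z), μ, σ²)))
```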

xukai92 commented 3 years ago

Thanks for the input here @sethaxen.

> Most models will have a mixture of real and complex parameters, which could not fit well in an array of uniform eltype. Then there's the question of how to handle the metric, whose dimension would not match the number of degrees of freedom.

I agree with you on these issues, which I didn't realise earlier. I was thinking of simply supporting arrays of complex numbers, not mixed ones, since reals can be represented as complex numbers. Even then, though, I was not sure whether HMC would be expected to sample correctly.

> so unless we have a good example where this just won't work, I recommend this package keep to sampling real-valued parameters.

Agreed. It seems a bit too involved to do this properly.

Also, as you pointed out, the original problem from @itsdfish does not really involve sampling on a complex domain.

sethaxen commented 3 years ago

> > Most models will have a mixture of real and complex parameters, which could not fit well in an array of uniform eltype. Then there's the question of how to handle the metric, whose dimension would not match the number of degrees of freedom.
>
> I agree with you on these issues, which I didn't realise earlier. I was thinking of simply supporting arrays of complex numbers, not mixed ones, since reals can be represented as complex numbers. Even then, though, I was not sure whether HMC would be expected to sample correctly.

Out of the box, I would not expect it to work. Since the imaginary components of the real parameters can vary without any change to the log probability, they make the target distribution improper. Moreover, a real diagonal metric, for example, would have the real and imaginary parts of each parameter share the same diagonal entry, so that entry would be adapted to two parameters at once: one of infinite scale and one of finite scale, which would not go well. To get reasonable performance, one would have to either indicate which imaginary parts should be zero and constrain them to zero, or augment the logpdf with a prior on the imaginary parts, which would still not solve the shared-scale issue.
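A toy sketch of the impropriety argument, assuming a standard normal target on a single real parameter (all function names are hypothetical): once the parameter is embedded as a complex number, the log density ignores the imaginary part and is flat along that direction, so it cannot integrate to 1. Adding a Normal(0, s) prior on the imaginary part, the second fix mentioned above, makes the target proper, but the two directions then have different scales (1 and s).

```julia
# Hypothetical real-valued target: standard normal log density
logp_real(x::Real) = -0.5 * x^2 - 0.5 * log(2π)

# Naive complex embedding: the imaginary part never enters the log density,
# so the density is flat (and therefore improper) along the imaginary axis.
logp_embedded(z::Complex) = logp_real(real(z))

# Augmented version: a Normal(0, s) prior on the imaginary part makes the target
# proper, but the real and imaginary directions now have scales 1 and s, which a
# single shared diagonal metric entry would have to cover at the same time.
logp_augmented(z::Complex; s::Real = 10.0) =
    logp_embedded(z) - 0.5 * (imag(z) / s)^2 - log(s * sqrt(2π))

println(logp_embedded(1.0 + 0.0im) == logp_embedded(1.0 + 100.0im))   # flat direction
println(logp_augmented(1.0 + 0.0im) > logp_augmented(1.0 + 100.0im))  # no longer flat
```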

itsdfish commented 3 years ago

Thank you for your explanation. The complications that you outlined make it clear why this is a difficult issue, and why changing AdvancedHMC is not a viable option. Another complicating factor is that there appears to be periodicity in this type of model, which propagates to the log likelihood surface (see below). My intuition is that NUTS and perhaps many other algorithms would struggle with this type of geometry.

Are there plans to support mutation in the future? Mutation is one of the great features of Julia in my opinion.

[Figure: quantum model log-likelihood surface]

xukai92 commented 3 years ago

> Another complicating factor is that there appears to be periodicity in this type of model, which propagates to the log likelihood surface (see below).

I guess you could reparameterise theta to get rid of this periodicity.

> Are there plans to support mutation in the future? Mutation is one of the great features of Julia in my opinion.

You can try Zygote.Buffer (https://fluxml.ai/Zygote.jl/latest/utils/#Zygote.Buffer) to do mutation in Zygote.
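A small sketch of the Zygote.Buffer pattern, assuming Zygote is installed: writes go into the buffer, and `copy` turns it back into an ordinary array before it is used further. The function `sumsq_buffered` is a toy example, not part of the model above.

```julia
using Zygote

# Toy function that fills an array element by element. Writing into an ordinary
# array here would make Zygote raise an error; a Buffer allows the writes.
function sumsq_buffered(x)
    buf = Zygote.Buffer(x)        # buffer like similar(x): same shape and eltype
    for i in eachindex(x)
        buf[i] = x[i]^2           # indexed writes are allowed on a Buffer
    end
    return sum(copy(buf))         # copy(buf) returns an ordinary array
end

g = Zygote.gradient(sumsq_buffered, [1.0, 2.0, 3.0])[1]
println(g)                        # gradient of sum(x .^ 2) is 2x
```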

itsdfish commented 3 years ago

Thank you both for your help. I'll go ahead and close this issue.

sethaxen commented 3 years ago

> Thank you for your explanation.

No problem!

> Another complicating factor is that there appears to be periodicity in this type of model, which propagates to the log likelihood surface (see below). My intuition is that NUTS and perhaps many other algorithms would struggle with this type of geometry.

That's right! Thankfully, as @xukai92 pointed out, this can be fixed with a simple reparameterization. Something like this:

@model model(data, S, n_sim) = begin
    x ~ Normal(0, 1)
    y ~ Normal(0, 1)
    θ = atan(y, x)
    data ~ Quantum(θ, S, n_sim)
end

This would make θ uniform on [-π, π], but the actual parameters are x and y, so you don't have to worry about wrapping θ from π to -π; that happens automatically. If you want to put a non-uniform prior on θ, say a von Mises distribution, you can do it by manually incrementing the log probability (target):

@model model(data, S, n_sim) = begin
    x ~ Normal(0, 1)
    y ~ Normal(0, 1)
    θ = atan(y, x)
    Turing.@addlogprob! logpdf(VonMises(2, 0.25), θ)
    data ~ Quantum(θ, S, n_sim)
end
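Both models rely on θ = atan(y, x), with x and y drawn from independent Normal(0, 1), being uniform on [-π, π]. A quick Monte Carlo sketch of that premise, using only the standard library (sample size and seed are arbitrary choices): the draws should have mean near 0, variance near π²/3 (the variance of Uniform(-π, π)), and about a quarter of the mass in each quarter of the circle.

```julia
using Random, Statistics

# Draw (x, y) from independent standard normals, as in the models above,
# and map each draw to θ = atan(y, x), which lies in [-π, π]
rng = MersenneTwister(42)
n = 200_000
x = randn(rng, n)
y = randn(rng, n)
θ = atan.(y, x)

# Uniform(-π, π) has mean 0 and variance (2π)^2 / 12 = π^2 / 3
θ_mean = mean(θ)
θ_var = var(θ)

# Fraction of draws in each quarter of the circle (≈ 0.25 each if uniform)
edges = range(-float(π), float(π); length = 5)
quarter_fracs = [mean((edges[k] .<= θ) .& (θ .< edges[k + 1])) for k in 1:4]
println((θ_mean, θ_var, π^2 / 3, quarter_fracs))
```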

This works for the reasons outlined in the Stan manual. You just need to be careful that the distribution you apply to θ accounts for its circularity. For example, you could apply Normal(2, 2) (I think the parameters I've chosen for VonMises should be close to this), but then you need to account for the density wrapping around the circle infinitely many times; see the wrapped normal distribution.
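A sketch of the wrapped-normal log density, assuming the usual definition as the Normal(μ, σ) density summed over all windings θ + 2πk. The infinite sum is truncated at |k| ≤ K (harmless when σ is small relative to 2πK), and `logpdf_wrappednormal` and `trapz` are hypothetical helpers. With the Normal(2, 2) parameters from the original model, the check confirms that the wrapped density integrates to 1 over any interval of width 2π and is 2π-periodic.

```julia
# Hypothetical helper: log density of a wrapped normal WN(μ, σ) on the circle.
# The exact density sums the Normal(μ, σ) density over all windings θ + 2πk;
# here the sum is truncated to |k| ≤ K and combined with a stable log-sum-exp.
function logpdf_wrappednormal(θ, μ, σ; K::Integer = 20)
    terms = [-abs2(θ - μ + 2π * k) / (2σ^2) for k in -K:K]
    m = maximum(terms)
    return m + log(sum(exp.(terms .- m))) - log(σ * sqrt(2π))
end

# Simple trapezoid rule, used only for a numerical normalization check
function trapz(f, a, b; n::Integer = 10_000)
    h = (b - a) / n
    s = (f(a) + f(b)) / 2
    for i in 1:(n - 1)
        s += f(a + i * h)
    end
    return s * h
end

# Wrapped version of the Normal(2, 2) prior from the original model
dens(θ) = exp(logpdf_wrappednormal(θ, 2.0, 2.0))
mass_symmetric = trapz(dens, -π, π)              # interval [-π, π]
mass_centered = trapz(dens, 2.0 - π, 2.0 + π)    # interval [μ - π, μ + π]
println((mass_symmetric, mass_centered))
```

In a Turing model, `logpdf_wrappednormal(θ, 2.0, 2.0)` could take the place of the von Mises term in `Turing.@addlogprob!`, under the same caveats about checking the result.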

> Are there plans to support mutation in the future? Mutation is one of the great features of Julia in my opinion.

As I understand it, there are, but not soon. This is apparently a hard problem that needs dedicated engineering effort to solve, and I don't completely understand it. For reference, some other AD systems, such as JAX, have the same limitation. Other reverse-mode ADs in Julia, such as ReverseDiff and Tracker, don't, but I think they also support complex numbers poorly or not at all.

itsdfish commented 3 years ago

@sethaxen, this is really helpful. I didn't even think about atan as a function to reparameterize theta. This might be useful even with other samplers or optimization methods.

Understandable about Zygote. From what I can tell, developing AD software is very challenging. It is certainly a stress test for the language. I hope that one day one of the packages can approach the performance of Stan without many sacrifices to flexibility.

I have one minor question: why does VonMises require Turing.@addlogprob! instead of ~? I want to make sure I know when to use Turing.@addlogprob!.

Edit: is it because theta is deterministic/dependent on x and y?

sethaxen commented 3 years ago

> I hope that one day one of the packages can approach the performance of Stan without many sacrifices to flexibility.

I haven't seen a direct comparison, so I actually don't know how Stan's AD compares to Julia's various AD packages.

> I have one minor question: why does VonMises require Turing.@addlogprob! instead of ~? I want to make sure I know when to use Turing.@addlogprob!.
>
> Edit: is it because theta is deterministic/dependent on x and y?

More or less. Turing.@addlogprob! can be dangerous, though, unless you know what you are doing. In this case it's fine, because without that line θ will be uniformly distributed on the interval [-π, π], and the von Mises density is with respect to the uniform (Lebesgue) measure on some interval of width 2π. But I just remembered that the VonMises logpdf implementation in Distributions is weird: its support is the interval [μ - π, μ + π], where μ is the mean angle, so it's better to roll your own logpdf whose support can be whichever interval of width 2π you want:

# log of exp(κ cos(x - μ)) / (2π I0(κ)); Distributions stores d.I0κx = I0(κ) * exp(-κ)
mylogpdf(d::VonMises, x) = d.κ * (cos(x - d.μ) - 1) - log(d.I0κx) - log(2π)

@model model(data, S, n_sim) = begin
    x ~ Normal(0, 1)
    y ~ Normal(0, 1)
    θ = atan(y, x)
    Turing.@addlogprob! mylogpdf(VonMises(2, 0.25), θ)
    data ~ Quantum(θ, S, n_sim)
end
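To see why the choice of interval does not matter for a density on the circle, here is a stand-alone sketch of the same von Mises log density as mylogpdf, written without Distributions.jl. I0(κ) is computed from its power series (an approximation adequate for moderate κ), and `besseli0_series`, `vm_logpdf` and `trapz` are hypothetical helpers. Because the density is 2π-periodic, it integrates to 1 over [μ - π, μ + π] and over [-π, π] alike.

```julia
# Hypothetical helper: I0(κ) from its power series Σ_k (κ/2)^(2k) / (k!)^2
function besseli0_series(κ::Real; terms::Integer = 60)
    t, s = 1.0, 1.0
    for k in 1:terms
        t *= (κ / 2)^2 / k^2      # ratio of consecutive series terms
        s += t
    end
    return s
end

# Von Mises log density: κ cos(x - μ) - log(2π I0(κ)), valid on any interval of width 2π
vm_logpdf(x, μ, κ) = κ * cos(x - μ) - log(2π * besseli0_series(κ))

# Simple trapezoid rule, used only for a numerical normalization check
function trapz(f, a, b; n::Integer = 10_000)
    h = (b - a) / n
    s = (f(a) + f(b)) / 2
    for i in 1:(n - 1)
        s += f(a + i * h)
    end
    return s * h
end

μ, κ = 2.0, 0.25                  # parameters from VonMises(2, 0.25) above
mass_centered = trapz(x -> exp(vm_logpdf(x, μ, κ)), μ - π, μ + π)   # Distributions' support
mass_symmetric = trapz(x -> exp(vm_logpdf(x, μ, κ)), -π, π)         # range of atan(y, x)
println((mass_centered, mass_symmetric))
```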

I strongly encourage you to check that this works correctly if you use it, e.g. by dropping the likelihood and drawing random samples from the prior, then drawing random samples from VonMises directly, and verifying that the two follow the same distribution.
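The check suggested above samples the prior with Turing. A cheaper stand-in, swapping in self-normalized importance weighting instead of MCMC, uses the fact that θ = atan(y, x) with standard normal x and y is uniform on [-π, π]: weighting those draws by the von Mises density should reproduce von Mises moments, for example E[cos(θ - μ)] = I1(κ)/I0(κ). This only checks the density formula on [-π, π], not the Turing model itself. The Bessel functions are evaluated by truncated power series, and all names are hypothetical.

```julia
using Random

# Hypothetical helper: modified Bessel function I_ν(κ) from its power series
#   I_ν(κ) = Σ_k (κ/2)^(2k+ν) / (k! (k+ν)!), truncated after `terms` terms
function besseli_series(ν::Integer, κ::Real; terms::Integer = 60)
    t = (κ / 2)^ν / factorial(ν)
    s = t
    for k in 1:terms
        t *= (κ / 2)^2 / (k * (k + ν))
        s += t
    end
    return s
end

# Von Mises log density, same formula as mylogpdf, valid on [-π, π]
vm_logpdf(θ, μ, κ) = κ * cos(θ - μ) - log(2π * besseli_series(0, κ))

# Hypothetical check: weight uniform angles by the von Mises density
function check_vonmises(μ, κ; n::Integer = 400_000, seed::Integer = 1)
    rng = MersenneTwister(seed)
    θ = atan.(randn(rng, n), randn(rng, n))   # atan(y, x), x, y ~ N(0, 1): uniform on [-π, π]
    w = exp.(vm_logpdf.(θ, μ, κ))             # target / uniform density, up to a constant
    w ./= sum(w)                              # self-normalize the weights
    est = sum(w .* cos.(θ .- μ))              # weighted estimate of E[cos(θ - μ)]
    exact = besseli_series(1, κ) / besseli_series(0, κ)
    return est, exact
end

est, exact = check_vonmises(2.0, 0.25)
println((est, exact))
```

Agreement here says the density is correctly normalized and shaped on [-π, π]; the Turing-based check is still the one that exercises `Turing.@addlogprob!`.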

itsdfish commented 3 years ago

Thank you @sethaxen. I will be sure to validate your recommendations.

Regarding AD performance, I have done some comparisons with MCMCBenchmarks.jl. What I have found is that Julia ADs are orders of magnitude slower than Stan for realistic models (unfortunately, I need to update and fix the package). Nonetheless, here is an example based on this issue comparing Stan to ReverseDiff: Stan requires 8.7 seconds, whereas ReverseDiff requires 501.4 seconds. Unfortunately, Zygote no longer works with this code, but we previously found that it is even slower than ReverseDiff. Here is the code.