TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Use AbstractVector in LKJ and LKJCholesky bijectors #253

Closed: harisorgn closed this 1 year ago

harisorgn commented 1 year ago

Expands on #246.

Use ::AbstractVector in VecCorrBijector operations, so we won't need to convert to ::AbstractMatrix and back (see the sketch below).
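For illustration, here is a hypothetical helper showing the vector representation involved (upper_to_vec is a made-up name, not the PR's actual code):

# Flatten the strict upper triangle of a correlation matrix into a vector,
# the representation the bijector can work with directly instead of an
# AbstractMatrix (illustrative only).
function upper_to_vec(X::AbstractMatrix)
    d = size(X, 1)
    return [X[i, j] for j in 2:d for i in 1:(j - 1)]
end

upper_to_vec([1.0 0.3; 0.3 1.0]) # [0.3]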

Add a bijector for LKJCholesky. I believe this was missing, and in practice it is the more efficient alternative when working with correlation matrices (it avoids a Cholesky decomposition on every call). LKJCholesky gives control over the returned factor ('U' -> UpperTriangular or 'L' -> LowerTriangular), as illustrated below. I was wondering whether we want to respect the factor choice and always return the same triangular factor. If yes, we can use VecTriuBijector and VecTrilBijector to retain information about the original factor in LKJCholesky and return it. If no, we can always work with one type, e.g. UpperTriangular.
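For reference, the factor control in Distributions.jl's LKJCholesky looks like this (existing upstream API, not code from this PR):

using Distributions, LinearAlgebra

d_upper = LKJCholesky(4, 1.0, 'U') # samples carry the upper factor
d_lower = LKJCholesky(4, 1.0, 'L') # samples carry the lower factor

c = rand(d_upper) # a LinearAlgebra.Cholesky
c.uplo            # 'U'
c.U               # the UpperTriangular factor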

TO DO:

Related to the second point above: in general, it would be nice if we could test these hand-derived analytical formulas for logabsdetjac. I played around with it a bit but couldn't come up with something. EDIT: This can be done using AD. I see there is already something implemented along these lines in test/transform.jl; it just needs some tweaking. A sketch of the idea follows.
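A minimal sketch of the AD-based check, using a toy elementwise bijection rather than the correlation bijectors (logabsdetjac_ad is a made-up helper; the actual machinery is in test/transform.jl):

using ForwardDiff, LinearAlgebra

# Compute log|det J| by AD, to compare against a hand-derived formula.
logabsdetjac_ad(f, x::AbstractVector) = first(logabsdet(ForwardDiff.jacobian(f, x)))

f(x) = log.(x)                # toy bijection: elementwise log
manual_ladj(x) = -sum(log, x) # hand-derived log|det J| of elementwise log

x = rand(3) .+ 0.1
isapprox(manual_ladj(x), logabsdetjac_ad(f, x); atol = 1e-10) # true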

cc @torfjelde if you want to have a look already

harisorgn commented 1 year ago

> It overall looks great :) But I think maybe we should just merge the VecTriuBijector and VecTrilBijector (+ their abstract type) into a single VecCorrBijector. This way we avoid type instabilities in bijector + we reduce the number of types.

Thanks for all the feedback @torfjelde, agreed with your suggestions :)

Initially, in 0c3aa39, I added a parametric Val type to VecCorrBijector so that the transform would be type-stable. But this made bijector(dist) unstable (VecCorrBijector was being inferred without the parametric Val). Eventually I switched to using a field in VecCorrBijector (called mode instead of uplo, because it can take a third value, 'C') and an if statement. Now the inferred type of transform is a Union of two types, which should be fine with union-splitting. See the sketch below.
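A simplified, hypothetical sketch of the trade-off (FactorPicker is a made-up stand-in, not the PR's VecCorrBijector):

using LinearAlgebra

# The mode lives in a runtime Char field rather than a Val type parameter,
# so constructing the object is always type-stable. The price: `factor`
# returns a small Union, which Julia handles via union-splitting.
struct FactorPicker
    mode::Char # 'U', 'L', or 'C'
end

function factor(p::FactorPicker, C::Cholesky)
    if p.mode === 'U'
        return C.U   # UpperTriangular
    elseif p.mode === 'L'
        return C.L   # LowerTriangular
    else             # 'C': reconstruct the full matrix
        return Matrix(C)
    end
end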

Have another look, and if you think all issues were addressed, we can merge.

harisorgn commented 1 year ago

> Have another look, and if you think all issues were addressed, we can merge.

Actually, let's wait: the AD tests I've just added on the roundtrip transformation fail. I will have a look.

harisorgn commented 1 year ago

Looked more into test_ad using the roundtrip (inverse-and-then-forward) transformation:

The test is failing on cholesky(Matrix{ForwardDiff.Dual...}). I see there is no frule defined for cholesky; not sure if something ForwardDiff-specific exists elsewhere. EDIT: ForwardDiff does not officially use ChainRules to define rules, but I also could not find a forward rule for cholesky in ForwardDiffChainRules.

Unsure about the Tracker error. Is the plan to keep supporting Tracker in general?

Zygote is returning nothing gradients, which is its "hard zero" IIUC. Not sure if it has to do with my usage of getproperty on the Cholesky, UpperTriangular and LowerTriangular types.

ReverseDiff is passing all tests after fixing the rule for pd_from_upper.

All these tests are for AD through transform, which IIUC is not relevant for Turing.jl usage. Since this PR builds on the previous one, which aims to get LKJ priors working in Turing, it might be worth merging now and tackling the remaining AD issues in a future PR?

torfjelde commented 1 year ago

Regarding the test-failures, it's a bit strange.

This seems related: https://github.com/JuliaDiff/ForwardDiff.jl/issues/606.

But I thought this was fixed, because they pulled 0.10.33 after the discussion there and deferred the breaking changes to 0.11. The tests are running on 0.10.35, so I don't get why we're seeing this 😕

Think this needs a bit further inspection.

And we do actually need to AD through the transform in some places, e.g. ADVI.

(Btw, I'm not done with my review, will continue later)

harisorgn commented 1 year ago

Good points @torfjelde, thanks! Here's more about the AD issues:

Tracker

From discussions elsewhere (Slack) I understand that we agree to drop support for this.

ForwardDiff

> This seems related: https://github.com/JuliaDiff/ForwardDiff.jl/issues/606.

It might actually not be. It seems like a numerical issue when comparing values in ishermitian.

I found two samples from the same LKJ where one passes and one fails. MWE:

using Bijectors, DistributionsAD, LinearAlgebra
using Bijectors: VecCorrBijector
using ForwardDiff
using ForwardDiff: Dual

b = VecCorrBijector('C') # bijector(LKJ(5,1))
binv = inverse(b)

f = x -> sum(b(binv(x)))

# x_f ~ LKJ(5,1)
x_f = [
    1.0  0.38808945715615550398  0.55251148082365042491   0.06333711952583508109  -0.51630779311225594164
    0.38808945715615550398  1.0  0.31760367441586356829   0.34585990227668395036   0.06051504059466897290
    0.55251148082365042491  0.31760367441586356829   1.0   0.17416714618194936715  -0.02825518349677474950
    0.06333711952583508109  0.34585990227668395036   0.17416714618194936715   1.0  -0.07513830680477201485
    -0.51630779311225594164  0.06051504059466897290  -0.02825518349677474950   -0.07513830680477201485   1.0
]
df_f = ForwardDiff.gradient(f, b(x_f)) # Errors, ishermitian returns false

# x_s ~ LKJ(5,1)
x_s = [
    1.0  -0.01569213125090618277 -0.79039374741027101923  -0.03400980954333766848   0.54371128016847525277
    -0.01569213125090618277   1.0 -0.19877390203937703173   -0.37124942960738860354  -0.39209191569764001439
    -0.79039374741027101923  -0.19877390203937703173   1.0   0.03430683023840974677  -0.62744676631878926187
    -0.03400980954333766848  -0.37124942960738860354   0.03430683023840974677   1.0   0.50841756191547016197
    0.54371128016847525277  -0.39209191569764001439    -0.62744676631878926187  0.50841756191547016197   1.0
]
df_s = ForwardDiff.gradient(f, b(x_s)) # Runs, ishermitian returns true

# Let's see where x_f fails
function ish(A::AbstractMatrix)
    # Just a copy of ishermitian with a `@show`
    indsm, indsn = axes(A)
    if indsm != indsn
        return false
    end
    for i = indsn, j = i:last(indsn)
        if A[i,j] != adjoint(A[j,i])
            @show abs(A[i,j] - adjoint(A[j,i]))
            return false
        end
    end
    return true
end

y_f = b(x_f)
ish(binv(Dual.(y_f))) # Returns false, shows abs(A[i, j] - adjoint(A[j, i])) = Dual{Nothing}(2.0816681711721685e-17)

# Without using `Dual`s though, all is good
ish(binv(y_f)) # Returns true

So ishermitian fails because of a very small difference between a single pair of adjoint elements. This is consistent across other samples from LKJ(5,1). Shall we remove the ishermitian check altogether? I'm not sure how safe that is, but trying out this uniform LKJ(5,1) over correlation matrices, all I get is tiny errors with Duals, as in the example.

EDIT: I tried using cholesky(x; check = false), but the gradients for these problematic samples are way off (~1e-1), even though the matrices fail to be Hermitian by only a tiny margin (~1e-17).

Zygote

This indeed has to do with getproperty(::Cholesky, :UL). In ChainRules there is an rrule defined for getproperty(::Cholesky, ::Symbol) that only accounts for the cases of :U and :L. So we have:

using Bijectors, DistributionsAD, LinearAlgebra
using Zygote

dist = LKJ(5, 1)
x = rand(dist)

g = x -> sum(cholesky(x).U)
dg = Zygote.gradient(g, x) # Returns correct gradient

h = x -> sum(cholesky(x).UL)
dh = Zygote.gradient(h, x) # Returns (nothing, )

So Zygote can be fixed by changing https://github.com/TuringLang/Bijectors.jl/blob/0d599e858131be8e9ae7af289b16f94112f502e7/src/utils.jl#L18 to X.U, taking the potential extra allocation (if uplo === :L), and always working with UpperTriangular downstream; see the sketch below. Using PDMats.chol_upper as suggested here results in the same issue, because it accesses getproperty(::Cholesky, :factors).
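A minimal sketch of that change (upper_factor is a made-up name, not the actual utils.jl function):

using LinearAlgebra

# Always go through `.U`, which ChainRules covers with an rrule. When the
# Cholesky stores the lower factor, `.U` allocates a transposed copy; that
# is the extra allocation mentioned above.
upper_factor(C::Cholesky) = C.U # UpperTriangular, regardless of C.uplo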

Any thoughts on how to handle the ForwardDiff and Zygote cases? I think the Zygote changes are more straightforward unless I'm missing something.

harisorgn commented 1 year ago

> I think the Zygote changes are more straightforward unless I'm missing something.

It is for the case of LKJ (changing X.UL to X.U works), but not for LKJCholesky. In the latter case we construct a Cholesky during the inverse transform, as this is the support of the distribution. I'm guessing the Cholesky constructor needs an rrule for Zygote to work.

harisorgn commented 1 year ago

Zygote is fixed. It was more straightforward than writing new rrules: just passing an X::Matrix instead of an X::UpperTriangular or X::LowerTriangular to Cholesky, and avoiding the X.data access. A sketch is below.
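A minimal sketch of the fix, under my reading of the change above (illustrative values, not the PR's actual code):

using LinearAlgebra

X = [2.0 1.0; 0.0 1.5] # dense matrix holding an upper Cholesky factor

# Before (assumed): a triangular wrapper was passed and `.data` accessed,
# which Zygote cannot differentiate through.
C_before = Cholesky(UpperTriangular(X))

# After: the plain-Matrix constructor Cholesky(factors, uplo, info) takes
# X directly, so no `.data` access is needed.
C_after = Cholesky(X, 'U', 0)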

ForwardDiff passed twice on the latest commit, but I changed nothing that would fix it. It probably has to do with the stochastic nature of the numerical error, as in the example above.

harisorgn commented 1 year ago

I restarted the Inference tests multiple times and the ForwardDiff test passes (the only failures are Tracker tests that are not marked as broken). I can't recreate this locally; I still get some failures and passes, as in the example above, and I have matched package versions, so I'm confused 😅

harisorgn commented 1 year ago

@torfjelde , I implemented your suggestions, thanks for the feedback again : )

I couldn't locally reproduce the DomainError that comes up in the AD test.

Also, disregard my previous confusion about reproducing the ForwardDiff numerical error. I was restarting an interface test that wasn't hitting it, hence it was passing. When the right interface test was run on CI, it failed just as it fails locally (see comments above).

So there are still these two errors, plus the stack-related one that is addressed in another PR here.

(Apologies for the format, only have phone access for now)

harisorgn commented 1 year ago

@torfjelde accidental merge, sorry, I was setting up git on a new machine 😅. Please revert it and I'll implement the last changes.

torfjelde commented 1 year ago

Is it maybe easier if you just take over the other PR? :) #246