TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Zygote AD & `logpdf` for transformed multivariate #217

Open · tpgillam opened this issue 2 years ago

tpgillam commented 2 years ago

I've found that Zygote fails to compute gradients when using the method of `logpdf` defined here: the gradient with respect to the bijector's parameters comes back as `nothing`.

Here's an MWE:

```julia
using Bijectors
using DistributionsAD
using Flux
using Zygote

d = MvNormal(zeros(2), ones(2))  # base distribution
b = PlanarLayer(2)               # trainable bijector
flow = transformed(d, b)

# Each column of `x` is one observation
x = [0.42 0.24; 0.42 0.24]

"""Use the optimised `logpdf` call."""
loss_(flow, x) = -sum(logpdf(flow, x))

"""Rearrange to use default `logpdf` in `Distributions`."""
function loss2_(flow, x)
    things = map(eachcol(x)) do obs
        logpdf(flow, obs)
    end
    return -sum(things)
end

@show loss_(flow, x)
@show loss2_(flow, x)

println()

# Implicit-parameter gradients w.r.t. the layer's parameters;
# `Flux.params(b)[1]` is the first parameter array of `b`
gs = gradient(() -> loss_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]]

gs = gradient(() -> loss2_(flow, x), Flux.params(b))
@show gs.grads[Flux.params(b)[1]];
```
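Both loss functions compute the same quantity: the negative log-likelihood of the columns of `x` under the flow. Assuming the standard change-of-variables form for a transformed distribution with base density $p_X$ (here `d`) and bijector $b$ (here `b`), and writing $y_j$ for the $j$-th column of `x`:

$$
\mathcal{L} = -\sum_{j} \log p_Y(y_j), \qquad \log p_Y(y) = \log p_X\big(b^{-1}(y)\big) + \log\left|\det J_{b^{-1}}(y)\right|
$$

`loss_` evaluates the whole sum through one batched `logpdf` call, while `loss2_` evaluates each term separately. That is why the two printed losses agree and only the gradients differ.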

With output:

```
loss_(flow, x) = 3.089176357252711
loss2_(flow, x) = 3.089176357252711

gs.grads[(Flux.params(b))[1]] = nothing
gs.grads[(Flux.params(b))[1]] = [-2.603210756288831, -4.3264084139896095]
```

Tested on Bijectors v0.10.0.

I'm not sure, but maybe the optimised `logpdf` method (or some of the methods it calls) needs additional ChainRules support?
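Before adding ChainRules support, one way to confirm that the per-column gradient printed above is the correct one is to compare it against central finite differences. The following is a minimal sketch using only Base Julia; `fd_gradient` is a hypothetical helper written for illustration, not part of Zygote, Flux, or Bijectors.

```julia
# Central finite-difference gradient of a scalar function `f` at `θ`.
# `fd_gradient` is a hypothetical helper for this sketch, not a library API.
function fd_gradient(f, θ::AbstractVector{<:Real}; h::Real = 1e-6)
    θf = float.(θ)                  # floating-point copy of the input
    g = similar(θf)
    for i in eachindex(θf)
        θp = copy(θf); θp[i] += h   # step forward in coordinate i
        θm = copy(θf); θm[i] -= h   # step backward in coordinate i
        g[i] = (f(θp) - f(θm)) / (2h)
    end
    return g
end

# Sanity check on a function with a known gradient: ∇ sum(abs2, θ) = 2θ
θ = [1.0, -2.0, 0.5]
g = fd_gradient(θ -> sum(abs2, θ), θ)
println(g)
```

For the MWE, one could take `p = Flux.params(b)[1]` and `θ0 = copy(p)`, evaluate `fd_gradient(θ -> (p .= θ; loss_(flow, x)), θ0)`, and then restore the parameters with `p .= θ0`. This assumes `p` is the same array that `flow` uses, so writing into it changes the loss; since the forward values of `loss_` are correct, the result should agree with whichever AD gradient is right.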