Stackoverflow in custom bijector #292

hanyas closed this issue 8 months ago

hanyas commented 8 months ago

I am trying to define my own Tanh bijector:

using Random
using Distributions
using LinearAlgebra

import Bijectors
using StatsFuns: softplus

struct Tanh <: Bijectors.Bijector end
(b::Tanh)(x::Real) = tanh(x)
(b::Tanh)(x) = map(b, x)

(ib::Bijectors.Inverse{<: Tanh})(y::Real) = atanh(y)
(ib::Bijectors.Inverse{<: Tanh})(y) = map(ib, y)

Bijectors.logabsdetjac(b::Tanh, x::Real) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
# The Jacobian of an elementwise map is diagonal, so sum the per-element terms
Bijectors.logabsdetjac(b::Tanh, x) = sum(xi -> Bijectors.logabsdetjac(b, xi), x)

dist = Distributions.MvNormal(zeros(1), I)
td = Bijectors.transformed(dist, Tanh())

y = rand(td)
Distributions.logpdf(td, y)

The forward transformation appears to work, but I don't understand the error I get when computing the logpdf:

ERROR: StackOverflowError:
     [1] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213
     [2] transform(t::Bijectors.Inverse{Tanh}, x::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:92
--- the last 2 lines are repeated 39990 more times ---
 [79983] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213
hanyas commented 8 months ago

I've also tried to define the bijector following the same recipe as the Scale bijector, but without success:

struct Tanh <: Bijector end

with_logabsdet_jacobian(b::Tanh, x) = transform(b, x), logabsdetjac(b, x)

transform(b::Tanh, x) = tanh(x)
transform(b::Tanh, x::AbstractVecOrMat) = tanh.(x)
transform(ib::Inverse{<:Tanh}, y) = atanh(y)
transform(ib::Inverse{<:Tanh}, y::AbstractVecOrMat) = atanh.(y)

logabsdetjac(b::Tanh, x::Real) = _logabsdetjac_tanh(b, x, Val(0))
function logabsdetjac(b::Tanh, x::AbstractArray{<:Real,N}) where {N}
    return _logabsdetjac_tanh(b, x, Val(N))
end

_logabsdetjac_tanh(b::Tanh, x::Real, ::Val{0}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
_logabsdetjac_tanh(b::Tanh, x::AbstractVector, ::Val{1}) = sum(xi -> _logabsdetjac_tanh(b, xi, Val(0)), x)
torfjelde commented 8 months ago

The default implementation of with_logabsdet_jacobian for a Bijector is (transform(b, x), logabsdetjac(b, x)). Since you haven't defined logabsdetjac(::Inverse{Tanh}, y), you also hit the default implementation of that, which for ib = inverse(b) is -logabsdetjac(b, ib(y)), i.e. minus the forward log-abs-det-Jacobian evaluated at x = ib(y).

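You then get a stack overflow because transform(::Inverse{Tanh}, y) is not defined either: its fallback calls with_logabsdet_jacobian, which in turn calls transform again. That is exactly the transform → with_logabsdet_jacobian loop in your stack trace. (Scale does not need an Inverse{<:Scale} implementation because its inverse is simply another Scale with the inverted scale factor.)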
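To make the recursion concrete, here is a stripped-down sketch in plain Julia. The names MyTransform, MyInverse, mytransform, with_ladj and ladj are hypothetical stand-ins chosen for illustration; the fallbacks only mimic the structure described above and are not copied from Bijectors' source. softplus is a local stand-in for StatsFuns.softplus.

```julia
# Hypothetical mimic of the fallback chain described above; these are NOT
# Bijectors' actual definitions, only the same shape of mutual recursion.
abstract type MyTransform end

struct MyInverse{T<:MyTransform}
    orig::T
end

# Generic fallbacks defined in terms of each other: unless a concrete
# `mytransform` method exists for `t`, these two recurse until the stack overflows.
mytransform(t, x) = first(with_ladj(t, x))
with_ladj(t, x) = (mytransform(t, x), ladj(t, x))

# Inverse rule: log|J_{b^-1}(y)| = -log|J_b(x)| with x = b^-1(y).
ladj(ib::MyInverse, y) = -ladj(ib.orig, mytransform(ib, y))

# Local, numerically stable stand-in for StatsFuns.softplus (assumption).
softplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

struct MyTanh <: MyTransform end
mytransform(::MyTanh, x::Real) = tanh(x)
ladj(::MyTanh, x::Real) = 2 * (log(2) - x - softplus(-2x))

# The missing piece: without this method, mytransform(MyInverse(MyTanh()), y)
# falls through to the generic fallback and never terminates.
mytransform(::MyInverse{MyTanh}, y::Real) = atanh(y)

x, l = with_ladj(MyInverse(MyTanh()), 0.5)
println((x, l))
```

Once the inverse's transform is defined, the generic inverse rule for the log-abs-det-Jacobian only needs the forward ladj, which mirrors why defining transform for Inverse{Tanh} (or sidestepping the Bijector type altogether) breaks the loop.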

In fact, you don't need the Bijector machinery here at all: tanh is already a function, so there is no need for a "new" type to represent it, and its inverse, atanh, is already defined as well.

I'd implement the above as:

using ChangesOfVariables, InverseFunctions, StatsFuns

InverseFunctions.inverse(::typeof(tanh)) = atanh
InverseFunctions.inverse(::typeof(atanh)) = tanh

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, _logabsdetjac_tanh(x)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(atanh), y::Real)
    x = atanh(y)
    # log|d atanh(y)/dy| = -log|d tanh(x)/dx| evaluated at x = atanh(y)
    return x, -_logabsdetjac_tanh(x)
end

# Use the irrational representation `StatsFuns.logtwo` to defer type-promotion.
# Similarly, I've removed all explicit usages of `Float64`, e.g. converted `2.0` to `2`
# to allow type-promotion to do its thing rather than forcing usage of `Float64`.
_logabsdetjac_tanh(x::Real) = 2 * (StatsFuns.logtwo - x - softplus(-2 * x))
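The closed form follows from d tanh(x)/dx = 1 - tanh(x)^2 = 4/(e^x + e^-x)^2, so log|tanh'(x)| = 2 log 2 - 2 log(e^x + e^-x) = 2(log 2 - x - log(1 + e^-2x)), and the last term is softplus(-2x). The sketch below checks this numerically in plain Julia, with log(2) and a locally defined softplus standing in for StatsFuns.logtwo and StatsFuns.softplus (an assumption made only so the snippet has no dependencies). It also shows why the closed form is preferable to the naive log(1 - tanh(x)^2), which returns -Inf once tanh(x) rounds to ±1.

```julia
# Local, numerically stable stand-in for StatsFuns.softplus (assumption).
softplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Same closed form as above, with log(2) in place of StatsFuns.logtwo.
_logabsdetjac_tanh(x::Real) = 2 * (log(2) - x - softplus(-2 * x))

# Naive form: fine for moderate x, but 1 - tanh(x)^2 becomes 0 for large |x|.
naive_logabsdetjac_tanh(x::Real) = log(1 - tanh(x)^2)

# Central finite difference of tanh as an independent reference.
fd_logabsdetjac_tanh(x::Real; h = 1e-6) = log(abs((tanh(x + h) - tanh(x - h)) / (2h)))

for x in (-3.0, -0.5, 0.0, 0.7, 2.5)
    println(x, "  ", _logabsdetjac_tanh(x), "  ",
            naive_logabsdetjac_tanh(x), "  ", fd_logabsdetjac_tanh(x))
end

# At x = 20, tanh(x) == 1.0 in Float64, so the naive version gives -Inf
# while the closed form stays finite.
println(naive_logabsdetjac_tanh(20.0), "  ", _logabsdetjac_tanh(20.0))

# Inverse direction, as in the atanh method: -_logabsdetjac_tanh(atanh(y))
# should equal log|d atanh(y)/dy| = -log(1 - y^2).
y = 0.3
println(-_logabsdetjac_tanh(atanh(y)), "  ", -log(1 - y^2))
```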

If you want a version that acts elementwise on arrays, you can use Bijectors.elementwise(f):

julia> using Bijectors

julia> elementwise(tanh)(rand(10))
10-element Vector{Float64}:

julia> with_logabsdet_jacobian(elementwise(tanh), rand(10))
([0.5475308984676883, 0.7498770212815672, 0.11406375475912378, 0.04598020777639154, 0.41278517115619784, 0.3067650082385844, 0.6441810700388316, 0.7430095366528289, 0.7023124306195118, 0.2806093226497268], -3.5844094465162772)
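Because tanh acts on each coordinate independently, the Jacobian of the elementwise map is diagonal and its log-abs-determinant is just the sum of the per-element terms, which is why the second element of the tuple above is a single number. A plain-Julia sketch of that bookkeeping follows; elementwise_tanh_with_ladj is a hypothetical helper (not how Bijectors.elementwise is implemented internally), cross-checked against LinearAlgebra.logabsdet of the dense Jacobian.

```julia
using LinearAlgebra

# Local, numerically stable stand-in for StatsFuns.softplus (assumption).
softplus(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))
logabsdetjac_tanh(x::Real) = 2 * (log(2) - x - softplus(-2 * x))

# Hypothetical helper: apply tanh elementwise and return the summed
# log-abs-det-Jacobian, mirroring the (y, logjac) tuple shape shown above.
function elementwise_tanh_with_ladj(x::AbstractVector{<:Real})
    y = tanh.(x)
    # Diagonal Jacobian => log|det J| is the sum of per-element log-derivatives.
    return y, sum(logabsdetjac_tanh, x)
end

x = collect(range(-2, 2; length = 10))
y, ladj = elementwise_tanh_with_ladj(x)

# Cross-check against the full Jacobian, built as a dense diagonal matrix.
J = Matrix(Diagonal(1 .- tanh.(x) .^ 2))
ladj_dense = first(logabsdet(J))
println(ladj, "  ", ladj_dense)
```

Summing per-element terms avoids ever materializing the n×n Jacobian; the dense check is only there to confirm the two agree.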