Closed: hanyas closed this issue 8 months ago
I've also tried to define the bijector by following a similar recipe to that of the `Scale` bijector, but without success:
```julia
struct Tanh <: Bijector end

with_logabsdet_jacobian(b::Tanh, x) = transform(b, x), logabsdetjac(b, x)

transform(b::Tanh, x) = tanh(x)
transform(b::Tanh, x::AbstractVecOrMat) = tanh.(x)
transform(ib::Inverse{<:Tanh}, y) = atanh(y)
transform(ib::Inverse{<:Tanh}, y::AbstractVecOrMat) = atanh.(y)

logabsdetjac(b::Tanh, x::Real) = _logabsdetjac_tanh(b, x, Val(0))
function logabsdetjac(b::Tanh, x::AbstractArray{<:Real,N}) where {N}
    return _logabsdetjac_tanh(b, x, Val(N))
end

_logabsdetjac_tanh(b::Tanh, x::Real, ::Val{0}) = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
_logabsdetjac_tanh(b::Tanh, x::AbstractVector, ::Val{1}) = sum(@. 2.0 * (log(2.0) - x - softplus(-2.0 * x)))
```
The default implementation of `with_logabsdet_jacobian` for a `Bijector` is `(transform(b, x), logabsdetjac(b, x))`, but since you haven't defined `logabsdetjac(::Inverse{Tanh}, y)`, you also hit the default implementation of this, which is `-logabsdetjac(inverse(b), inverse(b)(y))`. You then get a stack overflow error because `transform(::Inverse{Tanh}, y)` is also not defined (`Scale` does not have an `Inverse{<:Scale}` implementation because its inverse is just inverting the scale factor and returning a new `Scale`).
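If you did want to stay with the `Bijector` approach, one way to break that recursion is to define the inverse's `logabsdetjac` explicitly instead of relying on the default. A sketch, assuming the `Tanh` struct and its `logabsdetjac` methods from the question are in scope:

```julia
# Sketch: with this method defined, the default `with_logabsdet_jacobian`
# for `Inverse{Tanh}` no longer has to fall back on itself.
# `inverse(ib)` recovers the forward `Tanh` bijector.
logabsdetjac(ib::Inverse{<:Tanh}, y::Real) = -logabsdetjac(inverse(ib), atanh(y))
```

This uses the standard change-of-variables identity that the log-Jacobian of the inverse at `y` is the negated log-Jacobian of the forward map at `atanh(y)`.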
In fact, here you don't really need to mess around with the `Bijector` stuff at all: `tanh` is already a function, so you don't need a "new" representation of it, and its inverse `atanh` is similarly already defined. I'd implement the above as:
```julia
using ChangesOfVariables, InverseFunctions, StatsFuns

InverseFunctions.inverse(::typeof(tanh)) = atanh
InverseFunctions.inverse(::typeof(atanh)) = tanh

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(tanh), x::Real)
    y = tanh(x)
    return y, _logabsdetjac_tanh(x)
end

function ChangesOfVariables.with_logabsdet_jacobian(::typeof(atanh), y::Real)
    x = atanh(y)
    return x, -_logabsdetjac_tanh(x)
end

# Use the irrational representation `StatsFuns.logtwo` to defer type-promotion.
# Similarly, I've removed all explicit usages of `Float64`, e.g. converted `2.0` to `2`,
# to allow type-promotion to do its thing rather than forcing usage of `Float64`.
_logabsdetjac_tanh(x::Real) = 2 * (StatsFuns.logtwo - x - softplus(-2 * x))
```
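As a quick sanity check (a sketch, not from the original thread): the log-Jacobian of `tanh` is `log(1 - tanh(x)^2)`, and the `softplus` form above is an algebraically equivalent rewrite that stays finite where the naive form underflows:

```julia
using StatsFuns  # for softplus

# log|d/dx tanh(x)| = log(1 - tanh(x)^2)
#                   = 2 * (log(2) - x - softplus(-2x))
# The softplus form avoids computing 1 - tanh(x)^2, which rounds to 0
# in Float64 once |x| is large enough.
stable(x) = 2 * (StatsFuns.logtwo - x - softplus(-2 * x))
naive(x)  = log(1 - tanh(x)^2)

stable(1.5) ≈ naive(1.5)  # the two agree for moderate x
stable(20.0)              # finite (≈ -38.6), whereas naive(20.0) gives -Inf
```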
If you want a version that is supposed to act elementwise, then you can use `Bijectors.elementwise(f)`:
```julia
julia> using Bijectors

julia> elementwise(tanh)(rand(10))
10-element Vector{Float64}:
 0.22076308094447367
 0.06828859488600718
 0.3496810171644955
 0.02413051400382789
 0.6228303792319176
 0.5772825278828461
 0.7370222452215927
 0.45865543543291265
 0.6128386429868988
 0.7094298145373448

julia> with_logabsdet_jacobian(elementwise(tanh), rand(10))
([0.5475308984676883, 0.7498770212815672, 0.11406375475912378, 0.04598020777639154, 0.41278517115619784, 0.3067650082385844, 0.6441810700388316, 0.7430095366528289, 0.7023124306195118, 0.2806093226497268], -3.5844094465162772)
```
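To connect this back to computing a `logpdf`: the elementwise function can then be used to build a transformed distribution. A sketch (assuming the `inverse` and `with_logabsdet_jacobian` definitions above, and that `transformed` accepts a plain function wrapped in `elementwise`, as in recent Bijectors releases):

```julia
using Bijectors, Distributions, LinearAlgebra

# Push a standard multivariate normal through tanh elementwise;
# samples then live in (-1, 1)^3.
d = transformed(MvNormal(zeros(3), I), elementwise(tanh))

x = rand(d)
logpdf(d, x)  # no stack overflow: inverse and log-jacobian are both defined
```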
I am trying to define my own `Tanh` bijector. The forward transformation appears to be working, but I am struggling to understand the error I am receiving when computing the `logpdf`.