willtebbutt closed this 1 year ago.
I'm currently not able to try this out, but the following should do the trick:
```julia
using StatsFuns                                     # softplus, invsoftplus
import Bijectors: Bijector, Inversed, logabsdetjac  # explicit import so logabsdetjac can be extended

struct Softplus <: Bijector{0} end

# Forward
(::Softplus)(x::Real) = softplus(x)
(::Softplus)(x::AbstractArray{<:Real}) = softplus.(x)

# Backward
(::Inversed{<:Softplus})(y::Real) = invsoftplus(y)
(::Inversed{<:Softplus})(y::AbstractArray{<:Real}) = invsoftplus.(y)

# logabsdetjac (forward): log of d/dx softplus(x), which is always positive
logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x)) # I THINK this is right, haven't written it down
```
I'll make a PR and such when I'm back at the desk :+1:
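Writing the derivative down confirms the logabsdetjac expression in the code above. A short derivation from the definition softplus(x) = log(1 + e^x); the last form of the log-Jacobian is the same as the -log1pexp(-x) form suggested later in the thread:

```latex
\begin{aligned}
\operatorname{softplus}(x) &= \log\left(1 + e^{x}\right), \\
\frac{d}{dx}\operatorname{softplus}(x) &= \frac{e^{x}}{1 + e^{x}} = \frac{1}{1 + e^{-x}} > 0, \\
\log\left|\frac{d}{dx}\operatorname{softplus}(x)\right| &= x - \log\left(1 + e^{x}\right) = -\log\left(1 + e^{-x}\right), \\
\operatorname{softplus}^{-1}(y) &= \log\left(e^{y} - 1\right), \qquad y > 0.
\end{aligned}
```

Since the derivative is strictly positive, the absolute value can be dropped and softplus is a valid bijection from the reals onto (0, ∞).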
Sidenote: I think I realized a way to avoid this code duplication when working with "batches", i.e., define `(::Bijector{0})(x::AbstractArray{<:Real})` once and have all `<:Bijector{0}` types inherit it 🎉
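A minimal sketch of that idea in plain Julia, without Bijectors. `ScalarBijector`, `MySoftplus` and `MyExp` are hypothetical stand-ins for `Bijector{0}` and two concrete bijectors, and the naive `log1p(exp(x))` is used only for brevity. The array method is written once on the abstract supertype; each concrete type supplies only its scalar method:

```julia
# Hypothetical stand-in for Bijector{0}: a bijector acting on scalars.
abstract type ScalarBijector end

# Defined once on the supertype: applying any ScalarBijector to an array
# applies its scalar method to every element.
(b::ScalarBijector)(x::AbstractArray{<:Real}) = map(b, x)

# Concrete bijectors only implement the scalar case.
struct MySoftplus <: ScalarBijector end
(::MySoftplus)(x::Real) = log1p(exp(x))   # naive softplus, fine for a demo

struct MyExp <: ScalarBijector end
(::MyExp)(x::Real) = exp(x)

x = [-1.0, 0.0, 2.0]
MySoftplus()(x)   # elementwise, via the inherited array method
MyExp()(x)        # same inherited method, no extra code
```

The trade-off is that every subtype gets elementwise array behaviour whether or not that is the intended meaning for it.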
An alternative (as done, e.g., by `Distributions.logpdf`) would be to require all users to broadcast explicitly, i.e., users would just call `Softplus().(x)` when `x` is a vector.
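Under that convention a bijector needs only its scalar method, and Julia's dot syntax takes care of vectors. A minimal sketch; `PlainSoftplus` is a hypothetical stand-in for the `Softplus` type discussed here:

```julia
# Hypothetical bijector with only a scalar method and no array method at all.
struct PlainSoftplus end
(::PlainSoftplus)(x::Real) = log1p(exp(x))   # naive softplus, fine for a demo

x = [-1.0, 0.0, 2.0]

y = PlainSoftplus().(x)          # the caller broadcasts explicitly
applicable(PlainSoftplus(), x)   # false: calling it on the whole vector would be a MethodError
```

Because broadcasting is done by the caller, the call also fuses with surrounding dotted operations, e.g. `PlainSoftplus().(x) .+ 1`.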
> `logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x)) # I THINK this is right, haven't written it down`
Assuming it's correct (haven't thought about it), one should probably implement it as
```julia
logabsdetjac(::Softplus, x::Real) = -log1pexp(-x)
```
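Both forms are algebraically equal, but `x - log(1 + exp(x))` overflows for large `x` (`exp(1000.0)` is `Inf`) and loses relative accuracy for large positive `x`, where the true value is tiny. A self-contained comparison; `my_log1pexp` is a hypothetical, simplified stand-in for `LogExpFunctions.log1pexp` (the library version uses more refined cutoffs), and the finite-difference check is only a rough sanity test:

```julia
# Hypothetical, simplified stand-in for LogExpFunctions.log1pexp:
# computes log(1 + exp(x)) without overflowing for large x.
my_log1pexp(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Two algebraically equal forms of log|d/dx softplus(x)| = log(logistic(x)).
naive_ladj(x::Real)  = x - log(1 + exp(x))
stable_ladj(x::Real) = -my_log1pexp(-x)

# Central finite difference of softplus, then log: a rough numerical check.
fd_ladj(x::Real; h = 1e-6) = log((my_log1pexp(x + h) - my_log1pexp(x - h)) / (2h))

for x in (-3.0, 0.0, 3.0)
    println((x, naive_ladj(x), stable_ladj(x), fd_ladj(x)))  # the three values agree closely
end

naive_ladj(1000.0)    # -Inf, since exp(1000.0) overflows to Inf
stable_ladj(1000.0)   # -0.0; the true value is about -exp(-1000), which underflows to zero
```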
This is now provided by the ChangesOfVariables extension of LogExpFunctions: https://github.com/JuliaStats/LogExpFunctions.jl/blob/a1c4fda2b9cc4c59c184648c0cfc7f694c415bf3/ext/LogExpFunctionsChangesOfVariablesExt.jl#L7-L10
If someone has some time, a `softplus` Bijector would be cool to have.