TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

softplus bijector? #132

Closed. willtebbutt closed this issue 1 year ago.

willtebbutt commented 4 years ago

If someone has some time, a softplus Bijector would be cool to have.

torfjelde commented 4 years ago

I'm currently not able to try this out, but the following should do the trick:

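using Bijectors                  # provides Bijector, Inversed, logabsdetjac
using StatsFuns                  # provides softplus, invsoftplus
import Bijectors: logabsdetjac   # required to add a method to logabsdetjac from outside the package

# Scalar (0-dimensional) bijector: y = log(1 + exp(x)), mapping the real line onto (0, ∞)
struct Softplus <: Bijector{0} end

# Forward
(::Softplus)(x::Real) = softplus(x)
(::Softplus)(x::AbstractArray{<:Real}) = softplus.(x)

# Backward: invsoftplus(y) = log(exp(y) - 1), defined for y > 0
(::Inversed{<:Softplus})(y::Real) = invsoftplus(y)
(::Inversed{<:Softplus})(y::AbstractArray{<:Real}) = invsoftplus.(y)

# logabsdetjac (forward): log|d/dx softplus(x)|
logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x))  # I THINK this is right, haven't written it down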
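To sanity-check the forward/inverse pair and the proposed log-Jacobian without installing Bijectors.jl or StatsFuns, here is a plain-Julia sketch. The names `my_softplus`, `my_invsoftplus`, `proposed_logjac` and `fd_logjac` are hypothetical stand-ins introduced only for this check; the finite-difference reference is an assumption-free way to compare against the formula above.

```julia
# Hypothetical plain-Julia stand-ins for StatsFuns.softplus / invsoftplus,
# so this check runs without extra packages.
my_softplus(x::Real) = log1p(exp(-abs(x))) + max(x, zero(x))  # log(1 + exp(x)) without overflow
my_invsoftplus(y::Real) = y + log(-expm1(-y))                # log(exp(y) - 1), valid for y > 0

# The proposed log-Jacobian, transcribed literally.
proposed_logjac(x::Real) = x - log(1 + exp(x))

# Independent reference: log of a central finite difference of the forward map.
fd_logjac(x::Real; h = 1e-6) = log(abs(my_softplus(x + h) - my_softplus(x - h)) / (2h))

for x in (-3.0, -0.5, 0.0, 1.0, 4.0)
    roundtrip = my_invsoftplus(my_softplus(x))
    println("x = $x  roundtrip = $roundtrip  proposed = $(proposed_logjac(x))  fd = $(fd_logjac(x))")
end
```

For moderate inputs the round trip recovers `x` and the proposed formula agrees with the finite difference, which supports the "I THINK this is right" guess.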

I'll make a PR and such when I'm back at the desk :+1:

Sidenote: I think I realized a way we can avoid this code duplication for making things work when we look at "batches" (i.e., define (::Bijector{0})(x::AbstractArray{<:Real}) once and have all <:Bijector{0} inherit this) 🎉
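One way the "define it once" idea could look is a single array method on an abstract supertype, which every concrete scalar bijector then inherits. This is a minimal sketch, not Bijectors.jl code: `ScalarBijector`, `Softplus` and `Exp` here are hypothetical stand-ins for `Bijector{0}` and its subtypes, and defining call methods on an abstract type requires Julia 1.3 or later.

```julia
# Hypothetical abstract type standing in for Bijector{0} (scalar bijectors).
abstract type ScalarBijector end

struct Softplus <: ScalarBijector end
struct Exp <: ScalarBijector end

# Each concrete bijector only defines its scalar method...
(::Softplus)(x::Real) = log1p(exp(-abs(x))) + max(x, zero(x))  # log(1 + exp(x))
(::Exp)(x::Real) = exp(x)

# ...and one fallback, defined once on the abstract type, handles batches
# by applying the scalar method elementwise (map keeps the array's shape).
(b::ScalarBijector)(x::AbstractArray{<:Real}) = map(b, x)

println(Softplus()([0.0, 1.0]))
println(Exp()([0.0, 1.0]))
```

The same pattern would be needed for the inverse and for `logabsdetjac`; those fallbacks are omitted here.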

devmotion commented 4 years ago

> Sidenote: I think I realized a way we can avoid this code duplication for making things work when we look at "batches" (i.e., define (::Bijector{0})(x::AbstractArray{<:Real}) once and have all <:Bijector{0} inherit this) 🎉

An alternative (as done, e.g., by Distributions.logpdf) would be to require users to broadcast explicitly, i.e., users would just call Softplus().(x) when x is a vector.
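Under the explicit-broadcasting alternative, a bijector defines only its scalar method and callers opt in with dot syntax. A minimal sketch, assuming a hypothetical stand-alone `Softplus` type (no Bijectors.jl supertype) so it runs in plain Julia:

```julia
# Hypothetical scalar-only bijector: no AbstractArray method is defined at all.
struct Softplus end
(::Softplus)(x::Real) = log1p(exp(-abs(x))) + max(x, zero(x))  # log(1 + exp(x))

x = [-1.0, 0.0, 2.0]
y = Softplus().(x)   # the caller requests elementwise application explicitly

# Calling it on the whole vector is a MethodError, so the caller has to choose
# between broadcasting and a genuinely multivariate transform.
threw = try
    Softplus()(x)
    false
catch err
    err isa MethodError
end

println(y, "  whole-vector call threw MethodError: ", threw)
```

The design trade-off is the one raised above: no per-bijector array methods to maintain, at the cost of asking every caller to write the dot.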

devmotion commented 4 years ago

> logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x)) # I THINK this is right, haven't written it down

Assuming it's correct (I haven't thought about it), one should probably implement it as:

logabsdetjac(::Softplus, x::Real) = -log1pexp(-x)
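For reference, the two expressions are the same quantity. The derivative of softplus is the logistic sigmoid, so with $f(x) = \log(1 + e^{x})$:

```latex
\begin{aligned}
f'(x) &= \frac{e^{x}}{1 + e^{x}} = \frac{1}{1 + e^{-x}} > 0, \\
\log\lvert f'(x)\rvert &= x - \log(1 + e^{x})
  = -\log\!\left(e^{-x}(1 + e^{x})\right)
  = -\log(1 + e^{-x}).
\end{aligned}
```

The last form is what `-log1pexp(-x)` evaluates. Written as `x - log(1 + exp(x))`, `exp(x)` overflows to `Inf` in Float64 once x exceeds about 709.78, so the result becomes `-Inf` instead of a value near 0.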

sethaxen commented 1 year ago

This is now provided by LogExpFunctions's ChangesOfVariables extension: https://github.com/JuliaStats/LogExpFunctions.jl/blob/a1c4fda2b9cc4c59c184648c0cfc7f694c415bf3/ext/LogExpFunctionsChangesOfVariablesExt.jl#L7-L10
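ChangesOfVariables.jl's interface is `with_logabsdet_jacobian(f, x)`, which returns the pair `(f(x), log|det J|)`. Without reproducing the linked lines, here is a plain-Julia sketch of what such a method computes for softplus, compared with the naive formula from earlier in the thread. `stable_log1pexp` and `softplus_with_logjac` are hypothetical names, not the extension's code.

```julia
# Overflow-safe log(1 + exp(x)); a stand-in for LogExpFunctions.log1pexp.
stable_log1pexp(x::Real) = log1p(exp(-abs(x))) + max(x, zero(x))

# Hypothetical helper mirroring the shape of with_logabsdet_jacobian(f, x):
# returns (softplus(x), log|softplus'(x)|).
function softplus_with_logjac(x::Real)
    y = stable_log1pexp(x)
    return y, -stable_log1pexp(-x)   # log(sigmoid(x)) = -log(1 + exp(-x))
end

# The first proposal, evaluated literally.
naive_logjac(x::Real) = x - log(1 + exp(x))

for x in (-800.0, -5.0, 0.0, 5.0, 800.0)
    y, lj = softplus_with_logjac(x)
    println("x = ", x, "  y = ", y, "  logjac = ", lj, "  naive = ", naive_logjac(x))
end
```

At `x = 800.0` the naive version returns `-Inf` because `exp(800.0)` overflows, while the stable version returns a log-Jacobian of zero (the sigmoid is 1 to machine precision).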