TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License
199 stars 32 forks source link

Add an `rng` argument to `PlanarLayer` and `RadialLayer` for reproducibility #244

Open xiaomingfu2013 opened 1 year ago

xiaomingfu2013 commented 1 year ago

Hi, thank you for the nice package! While looking into `PlanarLayer` and `RadialLayer`, I wondered whether it would be worth adding an `rng` argument so that the randomly initialised parameters are reproducible, for example:

"""
    PlanarLayer([rng::AbstractRNG,] dims::Integer, wrapper=identity)
    PlanarLayer(dims::Integer, wrapper, rng::AbstractRNG)

Construct a `PlanarLayer` on `dims`-dimensional inputs whose parameters are drawn
from a standard normal distribution using `rng`:
`w` and `u` have length `dims`, and `b` has length 1.

`wrapper` is applied to each freshly drawn parameter array before the arrays are
passed to `PlanarLayer(w, u, b)`. The default, `identity`, leaves them unchanged.

Passing an explicit `rng` (e.g. `Random.MersenneTwister(1)`) makes the initial
parameters reproducible. Prefer the `rng`-first form, which follows the
`randn(rng, ...)` convention. The trailing-`rng` form is kept for backward
compatibility. When `rng` is omitted, `Random.default_rng()` is used.
"""
function PlanarLayer(rng::Random.AbstractRNG, dims::Integer, wrapper=identity)
    # Draw in a fixed order (w, u, b) so a given seed always yields the same layer.
    w = wrapper(randn(rng, dims))
    u = wrapper(randn(rng, dims))
    b = wrapper(randn(rng, 1))
    return PlanarLayer(w, u, b)
end

# Backward-compatible form with `rng` as the trailing positional argument.
function PlanarLayer(
    dims::Integer, wrapper=identity, rng::Random.AbstractRNG=Random.default_rng()
)
    return PlanarLayer(rng, dims, wrapper)
end

"""
    RadialLayer([rng::AbstractRNG,] dims::Integer, wrapper=identity)
    RadialLayer(dims::Integer, wrapper, rng::AbstractRNG)

Construct a `RadialLayer` on `dims`-dimensional inputs whose parameters are drawn
from a standard normal distribution using `rng`:
`α_` and `β` have length 1, and `z_0` has length `dims`.

`wrapper` is applied to each freshly drawn parameter array before the arrays are
passed to `RadialLayer(α_, β, z_0)`. The default, `identity`, leaves them unchanged.

Passing an explicit `rng` (e.g. `Random.MersenneTwister(1)`) makes the initial
parameters reproducible. Prefer the `rng`-first form, which follows the
`randn(rng, ...)` convention. The trailing-`rng` form is kept for backward
compatibility. When `rng` is omitted, `Random.default_rng()` is used.
"""
function RadialLayer(rng::Random.AbstractRNG, dims::Integer, wrapper=identity)
    # Draw in a fixed order (α_, β, z_0) so a given seed always yields the same layer.
    α_ = wrapper(randn(rng, 1))
    β = wrapper(randn(rng, 1))
    z_0 = wrapper(randn(rng, dims))
    return RadialLayer(α_, β, z_0)
end

# Backward-compatible form with `rng` as the trailing positional argument.
function RadialLayer(
    dims::Integer, wrapper=identity, rng::Random.AbstractRNG=Random.default_rng()
)
    return RadialLayer(rng, dims, wrapper)
end