TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Fitting Normalizing Flows with Flux #175

Closed: rcnlee closed this issue 3 years ago

rcnlee commented 3 years ago

Hello, I'm trying to follow the normalizing flows example on the wiki using Flux. However, I'm not able to use Flux.params as the parameter wrapper:

using Turing, Bijectors, Flux
b = PlanarLayer(2, Flux.params)

MethodError: no method matching PlanarLayer(::Params, ::Params, ::Params)
Closest candidates are:
  PlanarLayer(::Int64, ::Any) at /Users/rlee18/.julia/packages/Bijectors/OJrCc/src/bijectors/planar_layer.jl:26

Stacktrace:
 [1] PlanarLayer(dims::Int64, wrapper::typeof(params))
   @ Bijectors ~/.julia/packages/Bijectors/OJrCc/src/bijectors/planar_layer.jl:30
 [2] top-level scope
   @ In[35]:8
 [3] eval
   @ ./boot.jl:360 [inlined]
 [4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1094

I'm also having trouble getting it to work with Tracker:

using Turing, Bijectors, Flux, Tracker
b = PlanarLayer(2, param)
Flux.params(b)
> Params([])

Here is another example I tried:

using Turing, Bijectors, Flux, Tracker
d = MvNormal(zeros(2), ones(2));
b = PlanarLayer(2, param)
td = transformed(d, b)
Flux.params(td)
> Params([])

Thanks!

torfjelde commented 3 years ago

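So there's a difference between Flux.params and Tracker.param: Flux.params collects the parameters of an existing model into a Params object, while Tracker.param wraps a single array so that Tracker tracks it. PlanarLayer applies the wrapper to each of its parameter arrays, which is why passing Flux.params ends in a call to PlanarLayer(::Params, ::Params, ::Params).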

So the following is probably what you want to do:

julia> using Bijectors, Tracker

julia> b = PlanarLayer(2, Tracker.param)
PlanarLayer{TrackedArray{…,Vector{Float64}}, TrackedArray{…,Vector{Float64}}}([0.7640542684637393, 0.027394069286570465] (tracked), [0.6823998572495387, 1.2829720031601226] (tracked), [-0.2763080665702938] (tracked))

julia> x = randn(2)
2-element Vector{Float64}:
 -0.6370643499661355
 -0.8888414209863912

julia> b(x)
Tracked 2-element Vector{Float64}:
 -0.6157400027627787
 -1.7148377496240932
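
To illustrate why the choice of wrapper matters, here is a minimal sketch in plain Julia. `ToyLayer`, `ToyParams` and `collect_params` are hypothetical stand-ins invented for this example; the constructor shape is an assumption inferred from the stack trace in the question, not the Bijectors.jl source.

```julia
# Hypothetical stand-in for a layer whose fields must all be arrays.
struct ToyLayer{T<:AbstractVector}
    w::T
    u::T
    b::T
end

# Assumed shape of the convenience constructor: `wrapper` is applied to every
# freshly initialised parameter array before the fields are stored.
ToyLayer(dims::Int, wrapper=identity) =
    ToyLayer(wrapper(randn(dims)), wrapper(randn(dims)), wrapper(randn(1)))

# A wrapper that returns an array works: each field is still an AbstractVector.
ok = ToyLayer(2, x -> convert(Vector{Float64}, x))

# Hypothetical analogue of a parameter collection (it is not an array).
struct ToyParams
    items::Vector{Any}
end
collect_params(x) = ToyParams(Any[x])

# A wrapper that returns a collection leaves no matching three-argument method.
got_method_error = try
    ToyLayer(2, collect_params)
    false
catch err
    err isa MethodError
end
```

In this toy version the second call fails the same way as the example in the question: the wrapper has to return something array-like for each parameter, not a container of parameters.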

rcnlee commented 3 years ago

Thanks for the reply @torfjelde. Any idea why my last example doesn't work? It is taken directly from the README: "In those cases, it might be useful to use Flux.jl's Flux.params to extract the parameters." Thanks!

torfjelde commented 3 years ago

What version of Bijectors.jl are you on? Because that works on my end:

julia> b = PlanarLayer(2, Tracker.param)
PlanarLayer{TrackedArray{…,Vector{Float64}}, TrackedArray{…,Vector{Float64}}}([0.7640542684637393, 0.027394069286570465] (tracked), [0.6823998572495387, 1.2829720031601226] (tracked), [-0.2763080665702938] (tracked))

julia> Flux.params(b)
Params([[0.7640542684637393, 0.027394069286570465] (tracked), [0.6823998572495387, 1.2829720031601226] (tracked), [-0.2763080665702938] (tracked)])

julia> b = PlanarLayer(2)
PlanarLayer{Vector{Float64}, Vector{Float64}}([0.8750845996907893, -0.058422540801975766], [0.008596254577741316, -0.36822437116770174], [1.731448937866169])

julia> Flux.params(b)
Params([[0.8750845996907893, -0.058422540801975766], [0.008596254577741316, -0.36822437116770174], [1.731448937866169]])

rcnlee commented 3 years ago

It doesn't work in this configuration:

I seem to be stuck at Bijectors 0.8.13 for some reason. Calling update isn't getting me v0.9.0.

torfjelde commented 3 years ago

Ah, so Flux 0.12 uses Functors 0.2; Bijectors.jl is compatible only with Functors 0.1 at the moment. I'll see what I can do. But in the meantime, to upgrade Bijectors you need to downgrade Flux to 0.11.6.
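
One way to apply that downgrade is through Julia's built-in Pkg API, sketched below. The version numbers are the ones mentioned in this thread. Whether the resolver accepts them depends on the rest of the active environment, so treat this as a starting point, not a guaranteed recipe.

```julia
using Pkg

# Install the Flux release named above; Pkg re-resolves the other dependencies.
Pkg.add(name="Flux", version="0.11.6")

# Optionally hold Flux at that version so a later `Pkg.update()` leaves it alone.
Pkg.pin("Flux")

# Request the newer Bijectors release mentioned earlier in the thread
# (assumes it resolves against the pinned Flux).
Pkg.add(name="Bijectors", version="0.9.0")

# Print the versions that were actually resolved.
Pkg.status()
```

Once a Bijectors release compatible with Flux 0.12 is installed, `Pkg.free("Flux")` removes the pin so Flux can be updated again.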

rcnlee commented 3 years ago

Ah got it, thanks @torfjelde!

torfjelde commented 3 years ago

Np:) Btw, when https://github.com/TuringLang/Bijectors.jl/pull/166 is merged, Bijectors.jl should be usable with Flux 0.12:)