TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

rational quadratic flows not supporting Float32 input #266

Closed: zuhengxu closed this issue 1 year ago

zuhengxu commented 1 year ago

Here is a minimal working example:

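using Bijectors

K = 10; B = 2
# K widths, K heights and K - 1 derivatives, all Float32; the spline is defined on [-B, B].
# On the affected version this call throws the MethodError shown below.
Bijectors.RationalQuadraticSpline(randn(Float32, K), randn(Float32, K), randn(Float32, K - 1), B)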

which gives the error:


ERROR: MethodError: no method matching Bijectors.RationalQuadraticSpline(::Vector{Float64}, ::Vector{Float64}, ::Vector{Float32})

Closest candidates are:
  Bijectors.RationalQuadraticSpline(::T, ::T, ::T) where T<:(AbstractVector)
   @ Bijectors ~/.julia/packages/Bijectors/Fp0rB/src/bijectors/rational_quadratic_spline.jl:80
  Bijectors.RationalQuadraticSpline(::A, ::A, ::A, ::T2) where {T1, T2, A<:AbstractVector{T1}}
   @ Bijectors ~/.julia/packages/Bijectors/Fp0rB/src/bijectors/rational_quadratic_spline.jl:103

Stacktrace:
 [1] Bijectors.RationalQuadraticSpline(widths::Vector{Float32}, heights::Vector{Float32}, derivatives::Vector{Float32}, B::Int64)
   @ Bijectors ~/.julia/packages/Bijectors/Fp0rB/src/bijectors/rational_quadratic_spline.jl:109
 [2] top-level scope
   @ ~/Research/NF-playground/NF-examples/neural_spline_flow/nsf_layer.jl:55
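The MethodError is a dispatch failure: the three-argument candidate listed above only matches when all three vectors share one type T. The sketch below reproduces that constraint without Bijectors; `same_type_ctor` is a hypothetical stand-in, not part of the package.

```julia
# Hypothetical stand-in for a method like
# RationalQuadraticSpline(::T, ::T, ::T) where T<:AbstractVector:
# all three arguments must have the same concrete type T.
same_type_ctor(w::T, h::T, d::T) where {T<:AbstractVector} = (w, h, d)

w32 = randn(Float32, 10)
d32 = randn(Float32, 9)

# All three are Vector{Float32}: dispatch succeeds.
ok = same_type_ctor(w32, w32, d32)

# Two Vector{Float64} and one Vector{Float32}: no single T fits,
# so Julia throws a MethodError, matching the report.
w64 = Float64.(w32)
err = try
    same_type_ctor(w64, w64, d32)
    nothing
catch e
    e
end
println(err isa MethodError)  # prints true
```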

This is because lines 103-104 of https://github.com/TuringLang/Bijectors.jl/blob/dd8a24b02feccc8f1bef180e7419fb03a61ba82f/src/bijectors/rational_quadratic_spline.jl#LL99C1-L107C4 are not type stable: `.- 0.5` subtracts a Float64 literal, so a Float32 input is promoted and the result is Float64 instead of Float32, while line 105 keeps Float32. This is problematic because the three fields of `RationalQuadraticSpline{T}` must all have the same type, so the constructor call fails with the MethodError above.
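The promotion can be checked without Bijectors. The sketch below assumes the knots are built roughly as a cumulative softmax shifted by 0.5 and scaled onto [-B, B]; `softmax`, `knots_unstable` and `knots_stable` are hypothetical helpers (the local `softmax` stands in for `NNlib.softmax`). Converting the constants with `T(0.5)` and `T(B)` is one possible fix, not necessarily the change made in #267.

```julia
# Hypothetical helper: numerically stable softmax over a vector,
# a Base-only stand-in for NNlib.softmax.
function softmax(x::AbstractVector)
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Assumed knot construction: cumulative softmax mapped onto [-B, B].
# The Float64 literal 0.5 promotes a Float32 input to Float64.
knots_unstable(w::AbstractVector, B) =
    (cumsum(vcat([zero(eltype(w))], softmax(w))) .- 0.5) * 2 * B

# Type-stable variant: convert the constants to the input's element type.
function knots_stable(w::AbstractVector{T}, B) where {T<:AbstractFloat}
    return (cumsum(vcat([zero(T)], softmax(w))) .- T(0.5)) .* (2 * T(B))
end

w = randn(Float32, 10)
println(eltype(w .- 0.5))              # Float64
println(eltype(knots_unstable(w, 2)))  # Float64
println(eltype(knots_stable(w, 2)))    # Float32
```

Converting constants to the element type keeps the whole computation in the caller's precision, so the knot vectors end up with the same type as the derivatives and a single `T` fits all three fields.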

The same issue applies to https://github.com/TuringLang/Bijectors.jl/blob/dd8a24b02feccc8f1bef180e7419fb03a61ba82f/src/bijectors/rational_quadratic_spline.jl#LL109C1-L123C4.

torfjelde commented 1 year ago

Solved by #267