TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Reconstruct vs Function with array of parameters #1969

Open deveshjawla opened 1 year ago

deveshjawla commented 1 year ago

There is a significant difference in inference time between the reconstruct approach below and defining the model, i.e. the neural network, with a function feedforward() as further below. Even for a very simple and small problem like the Iris dataset, inference with the function approach finishes in under one minute, while with the reconstruct approach it takes 30 minutes. Any ideas why this happens and how to make the reconstruct approach comparable?

@model function bayesnnMVG(x, y, μ_prior, σ_prior, reconstruct)
    θ ~ MvNormal(μ_prior, σ_prior)
    nn = reconstruct(θ)
    ŷ = nn(x)
    for i = 1:lastindex(y)
        y[i] ~ Categorical(ŷ[:, i])
    end
end

where parameters_initial, reconstruct = Flux.destructure(nn_initial)

Compare this with the version below:

@model function bayesnnMVG(x, y, μ_prior, σ_prior)
    θ ~ MvNormal(μ_prior, σ_prior)
    nn = feedforward(θ)

    ŷ = nn(x)
    for i = 1:lastindex(y)
        y[i] ~ Categorical(ŷ[:, i])
    end
end

where

function feedforward(θ::AbstractVector)
    W0 = reshape(θ[1:20], 5, 4)  # first-layer weights sliced from θ
    ...
    model = Chain(
        Dense(W0, b0, relu),
        ...
    )
    return model
end

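To illustrate why the hand-written approach behaves differently, here is a dependency-free sketch in plain Julia. It swaps Flux's Chain/Dense for bare matrices, and the layer sizes 4 → 5 → 3 are assumptions chosen to match the 5×4 W0 above and the three Iris classes; the names feedforward_plain, relu and softmax_cols are hypothetical. Because every weight and bias is sliced out of θ, the element type of θ (Float64, or a ReverseDiff tracked type during sampling, as the @code_warntype output further down suggests) flows straight through, with no Float32 template involved.

```julia
# Hypothetical sketch: plain Julia stand-in for the thread's Flux-based feedforward.
# Layer sizes 4 -> 5 -> 3 are assumptions (4 Iris features, 3 classes).
relu(x) = max(zero(x), x)

# Numerically stable softmax over each column (one column per observation).
function softmax_cols(z::AbstractMatrix)
    e = exp.(z .- maximum(z; dims = 1))
    return e ./ sum(e; dims = 1)
end

# Slice every weight and bias out of the flat vector θ (43 entries in total),
# so all parameters share θ's element type.
function feedforward_plain(θ::AbstractVector)
    W0 = reshape(θ[1:20], 5, 4)   # hidden-layer weights
    b0 = θ[21:25]                 # hidden-layer bias
    W1 = reshape(θ[26:40], 3, 5)  # output-layer weights
    b1 = θ[41:43]                 # output-layer bias
    # Return a closure mapping a 4×N input matrix to 3×N class probabilities.
    return x -> softmax_cols(W1 * relu.(W0 * x .+ b0) .+ b1)
end

θ = randn(43)
nn = feedforward_plain(θ)
ŷ = nn(rand(4, 10))   # 3×10 matrix; each column sums to 1
```

Each column ŷ[:, i] is a probability vector, so it can be passed to Categorical in the same way as in the models above.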
Red-Portal commented 1 year ago

More information would be helpful: Are you using Zygote as the backend?

deveshjawla commented 1 year ago

Hi, thanks for responding. I use ReverseDiff for MCMC and ForwardDiff for VI. What do you think?

Red-Portal commented 1 year ago

Can you try Zygote and see if the discrepancy is smaller? Flux tends to be optimized for Zygote, so...

deveshjawla commented 1 year ago

No, Zygote has an issue with the for loop in the @model, and if I try to use LazyArray instead, it errors for Categorical:

MethodError: no method matching LazyArray(::Vector{Categorical{Float32, Vector{Float32}}})

torfjelde commented 1 year ago

I'd wager that Flux.destructure introduces a type instability, whereas your manually implemented feedforward doesn't.

You can check this by doing:

@code_warntype reconstruct(θ)
@code_warntype feedforward(θ)

Note that it might even be that reconstruct is type-stable, but that the resulting backwards-pass defined by ReverseDiff makes it unstable (there are ways to inspect this too, but for the moment check the above).
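As a complement to @code_warntype, a type-stability check can be made programmatic with Test.@inferred from Julia's standard library, which throws an error when a call's concrete return type differs from its inferred return type. The sketch below uses two hypothetical toy functions, not code from the thread; the same macro could in principle be pointed at reconstruct(θ) or feedforward(θ). Note that @inferred only checks the return type, so it will not catch instabilities hidden inside the function body or in the ReverseDiff backward pass.

```julia
using Test  # standard library; provides @inferred

# Hypothetical examples: `unstable` returns either x (Float64 here) or the Int 0,
# so inference can only conclude Union{Float64, Int64}.
unstable(x) = x > 0 ? x : 0
# `stable` always returns the same type as x.
stable(x) = x > 0 ? x : zero(x)

# @inferred returns the call's result when inferred and actual types agree.
r = @inferred stable(-1.5)
println("stable result: ", r)

# For the unstable function, @inferred throws instead.
caught = try
    @inferred unstable(1.5)
    false
catch err
    true
end
println("unstable flagged: ", caught)
```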

deveshjawla commented 1 year ago

When using restructure, the Chain parameters are Float32 but θ is Float64. I tried explicitly defining the prior on θ as Float32 Normal distributions, but it errored.

Here is a snippet from @code_warntype for reconstruct, although there are no red-marked types in its output:

MethodInstance for (::Optimisers.Restructure{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, Tuple{}}}}})(::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})

Similarly, @code_warntype for feedforward(θ) shows no red-marked types, but here the Chain is Float64:


From worker 3:      from feedforward(θ::AbstractVector) in Main at /Users/456828/Projects/Bayesian-Active-Learning/DataSets/coalminequakes_dataset/Network.jl:8
From worker 3:    Arguments
From worker 3:      #self#::Core.Const(feedforward)
From worker 3:      θ::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}
From worker 3:    Locals
From worker 2:    │   %16 = Main.Dense(W0, b0, Main.relu)::Dense{typeof(relu), ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}}}

devmotion commented 1 year ago

The Float32 parameters are a Flux-specific thing: Flux works mainly with Float32, and recently it has even started to enforce this more generally.
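The element-type mismatch can be seen in plain Julia: combining a Float32 array with Float64 values silently promotes the result to Float64, and for matrix products with mixed element types Julia typically falls back from BLAS to its generic kernels. The sketch below uses random stand-in arrays, not the thread's model. One option worth trying, though it is not tested in this thread, is converting the network with Flux.f64(nn_initial) before calling Flux.destructure, so that the template and θ share an element type.

```julia
# Stand-ins only: Flux creates layer parameters as Float32 by default,
# while the sampler's θ is Float64.
W32 = rand(Float32, 5, 4)   # stand-in for a Float32 Dense weight
x64 = rand(Float64, 4)      # stand-in for Float64 values coming from θ

# Mixed-precision product: Julia promotes the result to Float64.
y = W32 * x64
println(eltype(y))

# Converting the weights up front keeps every operand Float64.
W64 = Float64.(W32)
println(eltype(W64 * x64))
```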