Closed · DominiqueMakowski closed this 2 days ago
and I just wanted to have some confirmation of whether it was indeed bad practice to have data preparation steps inside the model
Yep! Data preparation steps should not be used in the model 👍
I was just curious why that would add so much overhead (as it does not seem to me related to the sampling bits)?
Arguably, one of the main selling points of Turing.jl is its generality: you can put whatever Julia code you want in the model and it should still be possible to run it. However, to achieve this, we have to execute everything you put inside the @model function ...
body every time we want to query the model. This in turn means that we don't do any "analysis" of the model; instead we just turn it all into a Julia function that we can execute in different ways. As a result, data preparation put in the model itself is executed every time we want to sample from the model, compute the log-joint probability of the model, and so on.
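For illustration, here is a minimal sketch of the pattern in question (the name model_slow and the variables are made up for the example, mirroring the data-preparation steps discussed below); every line of the body, including the data preparation, is re-run on each evaluation:

using Turing, Statistics

@model function model_slow(y, x, participants)
    # Data preparation inside the model body: re-executed on every
    # sampler step and every log-probability query.
    x = x .- mean(x)
    ppt_id = [findfirst(ppt .== unique(participants)) for ppt in participants]
    n = length(unique(ppt_id))

    # Priors
    μ_intercept ~ Normal(0.3, 0.5)
    μ_var1 ~ Normal(0, 0.5)
    # ... rest of the model
end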
With all this being said, it's very easy to keep things simple for your end-users without affecting performance :)
As you'll see if you define a @model in the Julia REPL, it will tell you that two methods have been defined, named (in your case) model_mixed. One of these is the constructor of the model, which returns a DynamicPPL.Model and takes the arguments you provided (in this case, y, x, and participants); the other is a modified version of the function you defined in @model, which takes some internal arguments that we define.
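For instance, the REPL interaction looks roughly like this (a sketch only; the exact signature of the internal evaluator method is a DynamicPPL implementation detail and varies between versions):

julia> using Turing

julia> @model function model_mixed(y, x, ppt_id, n)
           μ_intercept ~ Normal(0.3, 0.5)
           μ_var1 ~ Normal(0, 0.5)
           # ... rest of the model
       end
model_mixed (generic function with 2 methods)

julia> methods(model_mixed)  # one user-facing constructor + one internal evaluator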
So, if you want to do something like data preparation in the constructor, you can just remove it from the main @model and then define an alternative constructor by hand, e.g.
@model function model_mixed(y, x, ppt_id, n)
    # Priors
    μ_intercept ~ Normal(0.3, 0.5)
    μ_var1 ~ Normal(0, 0.5)
    ... # Rest of the model
end
# Define the alternative constructor which does the data preparation:
function model_mixed(y, x, participants)
    # Data preparation
    ppt_id = [findfirst(ppt .== unique(participants)) for ppt in participants]
    n = length(unique(ppt_id))
    x = x .- mean(x)
    return model_mixed(y, x, ppt_id, n)
end
The second method instance will never be called internally by Turing.jl (we just use the internal function generated by @model), so you won't encounter any performance issues.
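With that in place, end-users can construct the model from raw columns and the preparation runs only once, at construction time (the data below is just a stand-in, assuming participants is a vector of labels):

using Turing

participants = ["s1", "s1", "s2", "s2", "s3", "s3"]
y = randn(6)
x = randn(6)

model = model_mixed(y, x, participants)  # data preparation happens here, once
chain = sample(model, NUTS(), 1_000)     # sampling only re-runs the @model body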
One note of caution: make sure that you don't accidentally overwrite the original definition. For example, if n were a keyword argument, we'd have
@model function model_mixed(y, x, ppt_id; n = length(unique(ppt_id)))
    # Priors
    μ_intercept ~ Normal(0.3, 0.5)
    μ_var1 ~ Normal(0, 0.5)
    ... # Rest of the model
end
# Define the alternative constructor which does the data preparation:
function model_mixed(y, x, participants)
    # Data preparation
    ppt_id = [findfirst(ppt .== unique(participants)) for ppt in participants]
    n = length(unique(ppt_id))
    x = x .- mean(x)
    # NB: this call dispatches back to this very 3-positional-argument method
    return model_mixed(y, x, ppt_id; n = n)
end
In this case, we're redefining the 3-argument method instance of model_mixed, thus overriding the definition from @model. Calling the above would result in infinite recursion 😬 To avoid this, you can add type annotations to the arguments, or just give the alternative constructor a slightly different name, e.g. model_mixed_with_data_prep. Up to you 🤷
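For example, a sketch of the type-annotation route (the element type of participants here is an assumption about how your raw data is stored; adjust it to whatever you actually pass in):

using Statistics

# Annotating `participants` makes this method distinct from the
# 3-argument method generated by @model, so nothing gets overwritten.
function model_mixed(y, x, participants::AbstractVector{<:AbstractString})
    ppt_id = [findfirst(ppt .== unique(participants)) for ppt in participants]
    n = length(unique(ppt_id))
    x = x .- mean(x)
    # `ppt_id` is a vector of integers, so this call no longer matches
    # this method and dispatches to the @model-generated constructor:
    return model_mixed(y, x, ppt_id; n = n)
end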
Thanks again for the detailed explanation! It seems like the proposed solution works 🥳
This is more of a question than an issue per se, but I would be tempted to put data preparation steps inside models (e.g., transforming random factors to integers, deriving polynomials from a variable, standardizing a variable, etc.), so that the user doesn't have to worry about data preparation and can just input the "raw" variables from a dataframe, and the model would compute the necessary quantities itself.
In general, such data preparation is computationally inexpensive and should only happen once.
However, it seems from my trial and error that this drastically slows down model sampling, and I just wanted to have some confirmation of whether it was indeed bad practice to have data preparation steps inside the model and, if so, I was just curious why that would add so much overhead (as it does not seem to me related to the sampling bits)? Thanks!