Closed bvdmitri closed 5 months ago
Just to kick off the discussion, how about something as follows:
# Keyword-only entry point of the proposed refactoring. As far as this sketch shows, it:
#   1. bundles the option-like keywords into an `InferenceArguments` value,
#   2. validates `data` and the bundled arguments,
#   3. creates the model, invoking the `:before_model_creation` / `:after_model_creation` callbacks,
#   4. resolves the variables to track via `process_variables!`,
#   5. validates `iterations`,
#   6. runs the (elided) inference body, routing any thrown error to `__inference_process_error`.
# `model` and `data` have no default and therefore must be supplied by the caller.
function inference(;
model::ModelGenerator,
data,
initmarginals = nothing,
initmessages = nothing,
constraints = nothing,
meta = nothing,
options = nothing,
returnvars = nothing,
iterations = nothing,
free_energy = false,
free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
showprogress = false,
callbacks = nothing,
addons = nothing,
postprocess = DefaultPostprocess(),
warn = true
)
# create arguments structure for fixed options arguments
# (positional construction: the order below must match the field order of `InferenceArguments`,
# which is not visible here — TODO confirm)
args = InferenceArguments(
initmarginals,
initmessages,
constraints,
meta,
options,
returnvars,
iterations,
free_energy,
free_energy_diagnostics,
showprogress,
callbacks,
addons,
postprocess,
warn
)
# check whether data structure is valid
__inference_check_dicttype(:data, data)
# check whether arguments are valid
# NOTE(review): the return value of `check_arguments!` is discarded here, so any converted
# options it computes are not visible in this scope.
check_arguments!(args)
# create model
inference_invoke_callback(callbacks, :before_model_creation)
# NOTE(review): `_options` is never assigned in this function as shown; presumably it is meant to
# be the converted/validated `options` — confirm before use, otherwise this is an UndefVarError.
fmodel, freturnval = create_model(model, constraints = constraints, meta = meta, options = _options)
inference_invoke_callback(callbacks, :after_model_creation, fmodel, freturnval)
# process model and variables
vardict, actors, updates = process_variables!(fmodel, getreturnvars(args), getwarn(args))
# check iterations
# (`_iterations` is not used in the visible code; presumably consumed by the elided body below)
_iterations = check_iterations(iterations)
try
# the actual inference procedure is elided in this sketch
...
catch error
# every error raised by the body above is handed to the shared error handler
__inference_process_error(error)
end
end
# Validate the user-supplied inference arguments bundled in `args`.
#
# Performs the container-type checks for `initmarginals`, `initmessages` and `callbacks`,
# warns about unknown callback names and validates the inference options.
#
# All values are read from `args`: this function receives nothing but `args`, so the bare
# names `initmarginals`, `initmessages` and `callbacks` are not in scope here.
#
# Returns whatever `check_options!` returns (unchanged from the original last expression).
function check_arguments!(args)
    # NOTE(review): `getinitmarginals` / `getinitmessages` are assumed to exist next to the
    # other `InferenceArguments` accessors (`getcallbacks`, `getwarn`, ...) — TODO confirm.
    __inference_check_dicttype(:initmarginals, getinitmarginals(args))
    __inference_check_dicttype(:initmessages, getinitmessages(args))

    callbacks = getcallbacks(args)
    __inference_check_dicttype(:callbacks, callbacks)
    check_callbacks(callbacks, getwarn(args))

    return check_options!(getoptions(args), getaddons(args))
end
# Emit a warning for every key of `callbacks` that is not a recognised callback name.
#
# Does nothing when `warn` is `false` or when `callbacks` is `nothing`.
# Always returns `nothing`.
function check_callbacks(callbacks, warn)
    # Guard clause: validation only happens when warnings are enabled and callbacks were given
    (warn && !isnothing(callbacks)) || return nothing

    known_callbacks = (
        :on_marginal_update,
        :before_model_creation,
        :after_model_creation,
        :before_inference,
        :before_iteration,
        :before_data_update,
        :after_data_update,
        :after_iteration,
        :after_inference
    )

    for key in keys(callbacks)
        key in known_callbacks && continue
        @warn "Unknown callback specification: $(key). Available callbacks: on_marginal_update, before_model_creation, after_model_creation, before_inference, before_iteration, before_data_update, after_data_update, after_iteration, after_inference. Set `warn = false` to supress this warning."
    end

    return nothing
end
# Convert `options` to `ModelInferenceOptions` and apply the `addons` override.
#
# Arguments:
#   options - user-supplied options; converted with `convert(ModelInferenceOptions, options)`
#   addons  - value of the `addons` keyword; when not `nothing` it replaces the addons
#             stored in the converted options
#   warn    - when `true` (default), warn if both `addons` and the options specify addons
#
# Returns the resulting options object. Assigning to the local `options` is invisible to the
# caller, and without an explicit `return` the function would yield `nothing` whenever
# `addons` is `nothing`, so the result is always returned explicitly.
function check_options!(options, addons, warn = true)
    resolved = convert(ModelInferenceOptions, options)

    # The explicit `addons` keyword takes precedence over `options = (addons = ..., )`
    if !isnothing(addons)
        # The message advertises `warn = false`, so the warning has to honour that flag
        if warn && !isnothing(getaddons(resolved))
            @warn "Both `addons = ...` and `options = (addons = ..., )` specify a value for the `addons`. Ignoring the `options` setting. Set `warn = false` to supress this warning."
        end
        resolved = setaddons(resolved, addons)
    end

    return resolved
end
# Resolve which variables of `model` should be tracked and create the bookkeeping for them.
#
# Arguments:
#   model      - the created model; its variables are obtained with `getvardict(model)`
#   returnvars - the user's `returnvars` specification (may be `nothing` or a global option,
#                in which case `check_returnvars!` is expected to expand it)
#   warn       - when `true`, warn about entries that name unknown or non-random variables
#
# Returns the tuple `(vardict, actors, updates)`, where `actors` maps each valid random
# variable to the result of `make_actor` and `updates` maps the same keys to a fresh
# `MarginalHasBeenUpdated(false)` flag.
function process_variables!(model, returnvars, warn)
    vardict = getvardict(model)

    # `check_returnvars!` cannot rebind the caller's variable, so its return value has to be
    # used. `something` falls back to the original specification when the helper returns
    # `nothing`, which keeps this function correct with a helper that only returns a value
    # for the expanded case.
    # NOTE(review): `iterations` is not available here, so the helper cannot take it into
    # account when choosing a default — confirm whether it should be threaded through.
    resolved = something(check_returnvars!(returnvars, vardict), returnvars)
    __inference_check_dicttype(:returnvars, resolved)

    # Filter out unknown or non-random variables in the `returnvars` specification
    function __check_has_randomvar(vardict, variable)
        haskey_check = haskey(vardict, variable)
        israndom_check = haskey_check ? israndom(vardict[variable]) : false
        if warn && !haskey_check
            @warn "`returnvars` object has `$(variable)` specification, but model has no variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
        elseif warn && haskey_check && !israndom_check
            @warn "`returnvars` object has `$(variable)` specification, but model has no **random** variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
        end
        return haskey_check && israndom_check
    end

    # One actor per valid random variable entry
    actors = Dict(
        variable => make_actor(vardict[variable], value) for
        (variable, value) in pairs(resolved) if __check_has_randomvar(vardict, variable)
    )

    # One boolean flag per tracked variable to record whether its marginal has been updated
    updates = Dict(variable => MarginalHasBeenUpdated(false) for variable in keys(actors))

    return vardict, actors, updates
end
# Expand a global `returnvars` specification into a per-variable dictionary.
#
# Arguments:
#   returnvars - `nothing`, `KeepEach()`, `KeepLast()`, or an explicit per-variable
#                specification
#   vardict    - the model's variables, iterated with `pairs`
#   iterations - optional; only consulted when `returnvars` is `nothing`: a `Number` selects
#                `KeepEach()`, anything else (including the default `nothing`) selects
#                `KeepLast()`
#
# Returns the specification to use: a `Dict` with one entry per random, non-anonymous
# variable when `returnvars` was `nothing` or a global option, otherwise `returnvars`
# itself. Assigning to the local `returnvars` is invisible to the caller, so the result is
# returned explicitly.
function check_returnvars!(returnvars, vardict, iterations = nothing)
    if returnvars === nothing || returnvars === KeepEach() || returnvars === KeepLast()
        # `something` returns `returnvars` unless it is `nothing`, in which case the default
        # derived from `iterations` is used
        returnoption = something(returnvars, iterations isa Number ? KeepEach() : KeepLast())
        return Dict(
            variable => returnoption for
            (variable, value) in pairs(vardict) if (israndom(value) && !isanonymous(value))
        )
    end
    return returnvars
end
# Validate the `iterations` argument and return the number of iterations to run.
#
# `nothing` is replaced by 1. Raises an error (via `error`) when the resolved value is not
# an `Integer` or is not greater than zero.
function check_iterations(iterations)
    resolved = something(iterations, 1)

    if !(resolved isa Integer)
        error("`iterations` argument must be of type Integer or `nothing`")
    end

    if !(resolved > 0)
        error("`iterations` arguments must be greater than zero")
    end

    return resolved
end
The main thing done here is that I broke the code up into smaller functions. Note that this code will probably fail the tests because of the mutating functions, which might cause some errors (I did not check).
Great comment @bartvanerp. Some changes will not work exactly as written, e.g. the `check_returnvars!` function must return a modified version of the `returnvars` object, but it's a very good start. We can re-use these functions for the `rxinference` function as well.
P.S. Maybe without comments like
# check callbacks
check_callbacks(getcallbacks(args), getwarn(args))
as those neither add anything to the readability nor explain anything complex.
Thanks for the invitation. This is a good thing. I hope we get more reviews.
I like the proposal by Bart. Aside from details, it is the best next step to work on. In general, the current function is way too long (>1800 lines) and hard to follow. Bart's proposal makes the code more readable.
I see functions such as `check_callbacks`, `process_variables`, `check_options`, etc. The function name `inference` does not follow the same styling convention. `function infer` or `run_inference` would be better.
I have some suggestions on `data` or `data_stream`. First, it is clear from the code that data streaming is an add-on that came after "batch processing". The first mention of "online-streaming" is at line 704 of the `inference` function. The variable `datastream` is introduced at line 828. An application engineer will not find this. We need more emphasis on streaming. Batch processing is just a special (simpler) case of streaming.
A second comment is about making `data` or `datastream` mandatory variables. Observations are just pointmass constraints on the variational posterior. One should be able to do inference in a model with only form and factorization constraints, without any data. Constraining a posterior to be a factorized Gaussian does make the posterior different from the prior even if there are no observations. Is it possible, in one of the next updates, to treat data as constraints? That would make it more in line with the idea of CBFE minimization and should clean up some code, since data would not need special treatment.
A quick add-on to @bertdv's latest comment. When using RxInfer to implement AIF agents, we often need to run without data whenever we do planning. Currently this is most easily done by treating some variables as "datavars" to get inference started, even though they're not really data points, which is clunky. Being able to run inference with only constraints would make this part of the process a lot easier.
Can we close this one?
Indeed, I closed it because it's not an issue by itself, but rather a discussion on better code without a clear plan. We can keep the discussion internal or open a new issue with a clear plan.
Following the discussion at the seminar, I invite everyone to comment on the current implementation of the `inference` function. The implementation can definitely be improved/shortened, argument checking code can be put in separate functions, etc. I invite everyone to comment on whether the current implementation is readable at all: https://github.com/biaslab/RxInfer.jl/blob/0bb7bcc75c39af4d8012f4ac4f044ac9e6c57da2/src/inference.jl#L429