ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

The `inference` function implementation / Clean Code discussion #49

Closed by bvdmitri 5 months ago

bvdmitri commented 1 year ago

Following the discussion at the seminar, I invite everyone to comment on the current implementation of the `inference` function. The implementation can definitely be improved and shortened; for example, argument-checking code can be moved into separate functions. Please also comment on whether the current implementation is readable at all:

https://github.com/biaslab/RxInfer.jl/blob/0bb7bcc75c39af4d8012f4ac4f044ac9e6c57da2/src/inference.jl#L429

bartvanerp commented 1 year ago

Just to kick off the discussion, how about something like the following:

function inference(;
    model::ModelGenerator,
    data,
    initmarginals = nothing,
    initmessages = nothing,
    constraints = nothing,
    meta = nothing,
    options = nothing,
    returnvars = nothing,
    iterations = nothing,
    free_energy = false,
    free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
    showprogress = false,
    callbacks = nothing,
    addons = nothing,
    postprocess = DefaultPostprocess(),
    warn = true
)

    # create arguments structure for fixed options arguments
    args = InferenceArguments(
        initmarginals,
        initmessages,
        constraints,
        meta,
        options,
        returnvars,
        iterations,
        free_energy,
        free_energy_diagnostics,
        showprogress,
        callbacks,
        addons,
        postprocess,
        warn
    )

    # check whether data structure is valid
    __inference_check_dicttype(:data, data)

    # check whether arguments are valid
    check_arguments!(args)

    # create model
    inference_invoke_callback(callbacks, :before_model_creation)
    fmodel, freturnval = create_model(model, constraints = constraints, meta = meta, options = _options)
    inference_invoke_callback(callbacks, :after_model_creation, fmodel, freturnval)

    # process model and variables
    vardict, actors, updates = process_variables!(fmodel, getreturnvars(args), getwarn(args))

    # check iterations
    _iterations = check_iterations(iterations)

    try
        ...
    catch error
        __inference_process_error(error)
    end
end

function check_arguments!(args)

    # type checks for initmarginals, initmessages and callbacks
    __inference_check_dicttype(:initmarginals, getinitmarginals(args))
    __inference_check_dicttype(:initmessages, getinitmessages(args))
    __inference_check_dicttype(:callbacks, getcallbacks(args))

    # check callbacks
    check_callbacks(getcallbacks(args), getwarn(args))

    # check options
    check_options!(getoptions(args), getaddons(args))

end

function check_callbacks(callbacks, warn)

    # Check whether callbacks exist
    if warn && !isnothing(callbacks)
        for key in keys(callbacks)
            if key ∉ (
                :on_marginal_update,
                :before_model_creation,
                :after_model_creation,
                :before_inference,
                :before_iteration,
                :before_data_update,
                :after_data_update,
                :after_iteration,
                :after_inference
            )
                @warn "Unknown callback specification: $(key). Available callbacks: on_marginal_update, before_model_creation, after_model_creation, before_inference, before_iteration, before_data_update, after_data_update, after_iteration, after_inference. Set `warn = false` to suppress this warning."
            end
        end
    end

end

function check_options!(options, addons)

    # convert structure
    options = convert(ModelInferenceOptions, options)

    # Override `options` addons if the `addons` keyword argument is present 
    if !isnothing(addons)
        if !isnothing(getaddons(options))
            @warn "Both `addons = ...` and `options = (addons = ..., )` specify a value for the `addons`. Ignoring the `options` setting. Set `warn = false` to suppress this warning."
        end
        options = setaddons(options, addons)
    end

end

function process_variables!(model, returnvars, warn)

    # get dictionary of variables
    vardict = getvardict(model)

    # check return variables
    check_returnvars!(returnvars, vardict)

    __inference_check_dicttype(:returnvars, returnvars)

    # Use `__check_has_randomvar` to filter out unknown or non-random variables in the `returnvar` specification
    __check_has_randomvar(vardict, variable) = begin
        haskey_check   = haskey(vardict, variable)
        israndom_check = haskey_check ? israndom(vardict[variable]) : false
        if warn && !haskey_check
            @warn "`returnvars` object has `$(variable)` specification, but model has no variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
        elseif warn && haskey_check && !israndom_check
            @warn "`returnvars` object has `$(variable)` specification, but model has no **random** variable named `$(variable)`. The `$(variable)` specification is ignored. Use `warn = false` to suppress this warning."
        end
        return haskey_check && israndom_check
    end

    # Second, for each random variable entry we create an actor
    actors = Dict(variable => make_actor(vardict[variable], value) for (variable, value) in pairs(returnvars) if __check_has_randomvar(vardict, variable))

    # Third, for each random variable entry we create a boolean flag to track its updates
    updates = Dict(variable => MarginalHasBeenUpdated(false) for (variable, _) in pairs(actors))

    return vardict, actors, updates

end

function check_returnvars!(returnvars, vardict)
    # First, we check whether `returnvars` is `nothing` or one of the two possible values: `KeepEach` or `KeepLast`.
    # If so, we replace it with either `KeepEach` or `KeepLast` for each random and non-proxied variable in the model
    if returnvars === nothing || returnvars === KeepEach() || returnvars === KeepLast()
        # Checks if the first argument is `nothing`, in which case returns the second argument
        returnoption = something(returnvars, iterations isa Number ? KeepEach() : KeepLast())
        returnvars   = Dict(variable => returnoption for (variable, value) in pairs(vardict) if (israndom(value) && !isanonymous(value)))
    end
end

function check_iterations(iterations)
    _iterations = something(iterations, 1)
    _iterations isa Integer || error("`iterations` argument must be of type Integer or `nothing`")
    _iterations > 0 || error("`iterations` argument must be greater than zero")
    return _iterations
end

The main change here is that I broke the code up into smaller functions. Note that this code will probably fail the tests because of the mutating functions, which might cause some errors (I did not check).
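The mutating-function concern can be illustrated with a minimal, self-contained sketch. Julia passes arguments by sharing, so a line like `options = convert(ModelInferenceOptions, options)` inside `check_options!` rebinds only the local name and the caller never sees the result; the usual fix is to return the new value and use it. Here `options` is modelled as a plain NamedTuple rather than RxInfer's `ModelInferenceOptions`, and `check_options_broken!`, `check_options` and the field `some_option` are hypothetical names for illustration only.

```julia
# Hypothetical stand-in for `check_options!`: `options` is a NamedTuple here,
# not RxInfer's `ModelInferenceOptions`.
function check_options_broken!(options, addons)
    if !isnothing(addons)
        # Rebinds the *local* name only; the caller's `options` is unchanged.
        options = merge(options, (addons = addons,))
    end
    return nothing
end

# Same logic, but the updated options are returned to the caller.
function check_options(options, addons; warn = true)
    if !isnothing(addons)
        if warn && !isnothing(get(options, :addons, nothing))
            @warn "Both `addons = ...` and `options = (addons = ..., )` specify a value for the `addons`. Ignoring the `options` setting."
        end
        options = merge(options, (addons = addons,))
    end
    return options  # the caller must use this returned value
end

opts = (addons = nothing, some_option = 1)   # hypothetical options
check_options_broken!(opts, :my_addon)       # `opts.addons` is still `nothing`
_options = check_options(opts, :my_addon)    # `_options.addons` is `:my_addon`
```

With this shape, the `create_model(...; options = _options)` call in the proposal would receive `_options = check_options(getoptions(args), getaddons(args))`.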

bvdmitri commented 1 year ago

Great comment, @bartvanerp. Some changes will not work as written, e.g. the `check_returnvars!` function must return a modified version of the `returnvars` object, but it's a very good start. We can reuse these functions for the `rxinference` function as well.
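A minimal sketch of that change, under stated assumptions: `KeepEach`, `KeepLast` and `VarInfo` below are local stand-ins (not RxInfer's types), and `resolve_returnvars` is a hypothetical name. It returns the resolved Dict instead of rebinding its argument, and takes `iterations` explicitly, since `iterations` is not in scope inside `check_returnvars!` as proposed.

```julia
# Local stand-ins for RxInfer's `KeepEach`/`KeepLast` and for the entries of
# the model's variable dictionary (hypothetical, for illustration only).
struct KeepEach end
struct KeepLast end
struct VarInfo
    israndom::Bool
    isanonymous::Bool
end

# Returns a new `returnvars` Dict instead of rebinding the argument.
function resolve_returnvars(returnvars, vardict, iterations)
    if returnvars === nothing || returnvars === KeepEach() || returnvars === KeepLast()
        # Default: keep every iteration's result if `iterations` is given, else only the last
        returnoption = something(returnvars, iterations isa Number ? KeepEach() : KeepLast())
        return Dict(name => returnoption for (name, v) in pairs(vardict)
                    if v.israndom && !v.isanonymous)
    end
    return returnvars  # an explicit user specification is passed through unchanged
end

vardict = Dict(
    :x    => VarInfo(true, false),   # random, named: selected
    :y    => VarInfo(false, false),  # not random: skipped
    :_tmp => VarInfo(true, true),    # anonymous: skipped
)
returnvars = resolve_returnvars(nothing, vardict, 10)
```

The caller then writes `returnvars = resolve_returnvars(returnvars, vardict, iterations)`, which makes the data flow explicit.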

P.S. Maybe leave out comments like

# check callbacks
check_callbacks(getcallbacks(args), getwarn(args))

as they neither improve readability nor explain anything complex.
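Along the same lines, `check_callbacks` can be made self-documenting without comments by keeping the list of known callbacks in a single constant, so the membership check and the warning text cannot drift apart. This is a hypothetical sketch: `AVAILABLE_CALLBACKS` is an illustrative name, and the function returns the unknown keys only to make it easy to test.

```julia
# Single source of truth for the callback names listed in the proposal.
const AVAILABLE_CALLBACKS = (
    :on_marginal_update, :before_model_creation, :after_model_creation,
    :before_inference, :before_iteration, :before_data_update,
    :after_data_update, :after_iteration, :after_inference,
)

function check_callbacks(callbacks, warn)
    (warn && !isnothing(callbacks)) || return Symbol[]
    unknown = [key for key in keys(callbacks) if key ∉ AVAILABLE_CALLBACKS]
    for key in unknown
        @warn "Unknown callback specification: $(key). Available callbacks: $(join(AVAILABLE_CALLBACKS, ", ")). Set `warn = false` to suppress this warning."
    end
    return unknown  # returned only so the behaviour is easy to test
end

check_callbacks((before_inference = () -> nothing, on_typo = () -> nothing), true)
```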

bertdv commented 1 year ago

Thanks for the invitation. This is a good thing. I hope we get more reviews.

I like the proposal by Bart. Aside from details, it is the best next step to work on. In general, the current function is way too long (>1800 lines) and hard to follow. Bart's proposal makes the code more readable.
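One further readability step in the spirit of Bart's proposal, offered as a hedged sketch: the 14-argument positional `InferenceArguments(...)` call is easy to get wrong, since several arguments are `Bool`s or default to `nothing` and can be swapped silently. `Base.@kwdef` gives a struct a keyword constructor with defaults. The trimmed field list below is illustrative, not RxInfer's actual definition.

```julia
# Hypothetical, trimmed-down `InferenceArguments` with keyword construction.
Base.@kwdef struct InferenceArguments
    initmarginals = nothing
    initmessages  = nothing
    iterations    = nothing
    free_energy::Bool  = false
    showprogress::Bool = false
    warn::Bool         = true
end

# Fields are named at the call site; omitted ones take their defaults.
args = InferenceArguments(iterations = 10, free_energy = true)
```

Inside `inference`, the keyword arguments could then be forwarded with Julia's shorthand `InferenceArguments(; initmarginals, initmessages, iterations, free_energy, showprogress, warn)` (available since Julia 1.5).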

I see functions such as `check_callbacks`, `process_variables` and `check_options`, whose names are verbs. The function name `inference` does not follow the same styling convention; `infer` or `run_inference` would be better.

I have some suggestions on `data` or `data_stream`. First, it is clear from the code that data streaming is an add-on that came after "batch processing". The first mention of "online-streaming" is at line 704 of the `inference` function, and the variable `datastream` is introduced at line 828. An application engineer will not find this. We need more emphasis on streaming: batch processing is just a special (simpler) case of streaming.
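The claim that batch processing is a special case of streaming can be checked on a toy conjugate model. The sketch below does not use RxInfer's API; `GaussianBelief`, `update`, `stream_posterior` and `batch_posterior` are hypothetical names. It infers an unknown mean μ from observations with known precision τ and shows that folding the one-observation streaming update over a finite dataset gives the same posterior as the closed-form batch update.

```julia
# Toy model (not RxInfer): μ ~ N(mean, 1/precision), y ~ N(μ, 1/τ), τ known.
struct GaussianBelief
    mean::Float64
    precision::Float64
end

# One streaming step: absorb a single observation `y`.
function update(b::GaussianBelief, y::Real, τ::Real)
    p = b.precision + τ
    return GaussianBelief((b.precision * b.mean + τ * y) / p, p)
end

# Streaming: observations are consumed one at a time, in arrival order.
stream_posterior(prior, ys, τ) = foldl((belief, y) -> update(belief, y, τ), ys; init = prior)

# Batch: closed form over the whole dataset at once.
function batch_posterior(prior, ys, τ)
    p = prior.precision + length(ys) * τ
    return GaussianBelief((prior.precision * prior.mean + τ * sum(ys)) / p, p)
end

prior = GaussianBelief(0.0, 1.0)
ys = [1.2, 0.8, 1.1, 0.9]
s = stream_posterior(prior, ys, 2.0)
b = batch_posterior(prior, ys, 2.0)
```

Read this way, a batch is a stream that happens to be finite and available up front, which supports designing the API around streams first.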

My second comment concerns `data` or `datastream` being mandatory arguments. Observations are just point-mass constraints on the variational posterior. One should be able to do inference in a model with only form and factorization constraints, without any data: constraining a posterior to be a factorized Gaussian makes it differ from the prior even if there are no observations. Would it be possible, in one of the next updates, to treat data as constraints? That would be more in line with the idea of CBFE minimization and should clean up some code, since data would no longer need special treatment.
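The point that a factorization constraint alone changes the posterior can be seen in a textbook case that needs no data at all. Take a correlated 2-D Gaussian prior and impose a mean-field Gaussian constraint q(x1) q(x2): the KL(q || p)-optimal factors keep the prior mean but have precision `Λ[i, i]`, the diagonal of the prior precision matrix, so their variances are smaller than the prior marginals. This is a plain-Julia sketch of that closed-form result, not RxInfer code.

```julia
using LinearAlgebra

# Correlated 2-D Gaussian prior p(x1, x2) = N(0, Σ); there are no observations.
Σ = [1.0 0.8; 0.8 1.0]
Λ = inv(Σ)  # prior precision matrix

# Mean-field Gaussian constraint q(x1, x2) = q(x1) q(x2).
# The KL(q || p)-optimal factors have precision Λ[i, i], i.e. variance 1 / Λ[i, i].
meanfield_var = [1 / Λ[i, i] for i in 1:2]

# Exact marginal variances of the prior, for comparison.
prior_var = diag(Σ)
```

Here the constrained "posterior" has variance 0.36 per coordinate versus 1.0 for the prior marginals, although no data entered the computation.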

MagnusKoudahl commented 1 year ago

A quick addition to @bertdv's latest comment: when using RxInfer to implement AIF agents, we often need to run inference without data whenever we do planning. Currently, this is most easily done by treating some variables as "datavars" to get inference started, even though they are not really data points, which is clunky. Being able to run inference with only constraints would make this part of the process a lot easier.

albertpod commented 6 months ago

Can we close this one?

bvdmitri commented 5 months ago

Indeed, I closed it because it is not an issue by itself, but rather a discussion about better code without a clear plan. We can continue the discussion internally or open a new one with a clear plan.