cscherrer / SossMLJ.jl

SossMLJ makes it easy to build MLJ machines from user-defined models written in the Soss probabilistic programming language.
https://cscherrer.github.io/SossMLJ.jl/stable/
MIT License

We need to fix the implementation of `MLJModelInterface.predict` for classifiers #92

Open DilumAluthge opened 4 years ago

DilumAluthge commented 4 years ago

In #91, I added an incorrect implementation of `MMI.predict` for classifiers. This allowed me to finish the pipeline, do cross-validation, add additional tests, etc.

But we should fix the implementation of `MMI.predict` before we register the package.

Basically, we need a method `MMI.predict` that outputs a `Vector{UnivariateFinite}`.
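For reference, the required return shape can be sketched with MLJBase alone (illustrative only: the levels and probabilities here are made up, and this is not the SossMLJ implementation):

```julia
using MLJBase

# One UnivariateFinite per observation is what `predict` must return.
# Each row of `probs` holds the class probabilities for one observation.
probs = [0.7 0.2 0.1;
         0.5 0.3 0.2]
preds = [UnivariateFinite(["a", "b", "c"], probs[i, :], pool=missing)
         for i in 1:size(probs, 1)]
# `preds` is a length-2 vector of UnivariateFinite distributions
```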

cscherrer commented 4 years ago

Just some notes here as I think through this...

`predict_joint` gives us a `SossMLJPredictor`. For the multinomial example, the fields have types

julia> [p => typeof(getproperty(predictor_joint, p)) for p in propertynames(predictor_joint)]
4-element Array{Pair{Symbol,DataType},1}:
 :model => SossMLJModel{UnivariateFinite,Soss.Model{NamedTuple{(:X, :pool),T} where T<:Tuple,TypeEncoding(begin
    k = length(pool.levels)
    p = size(X, 2)
    β ~ Normal(0.0, 1.0) |> iid(p, k)
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end),TypeEncoding(Main)},NamedTuple{(:pool,),Tuple{CategoricalPool{String,UInt8,CategoricalValue{String,UInt8}}}},typeof(dynamicHMC),Symbol,typeof(SossMLJ.default_transform)}
  :post => Array{NamedTuple{(:β,),Tuple{Array{Float64,2}}},1}
  :pred => Soss.Model{NamedTuple{(:X, :pool, :β),T} where T<:Tuple,TypeEncoding(begin
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end),TypeEncoding(Main)}
  :args => NamedTuple{(:X, :pool),Tuple{Array{Float64,2},CategoricalPool{String,UInt8,CategoricalValue{String,UInt8}}}}

Abstractly, we need a mixture of instantiations of `pred`, one component for each value of `post`.

There's a bit more to it, though, because we need to return just the last distribution, so the result will (in most cases) no longer be a Soss model. This part will require a new Soss method, which I can put together.

This will get us to a "mixture over the response distributions". Then for the special case of `UnivariateFinite` (and also `Categorical` and `Multinomial`), we'll need a method that says a mixture of `UnivariateFinite`s is just another `UnivariateFinite`.

This won't be just any mixture: the components will have equal weight. I have an `EqualMix` in Soss that will at least be a good starting point for this.
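The collapse for the finite case is easy to check with plain probability vectors (a toy sketch; `mix_probs` is a hypothetical helper, not part of Soss or SossMLJ):

```julia
# An equal-weight mixture of categorical distributions over the same levels
# is itself categorical, with the averaged probability vector.
mix_probs(components) = sum(components) ./ length(components)

# Two posterior draws give two predictive distributions over three classes:
p1 = [0.7, 0.2, 0.1]
p2 = [0.5, 0.3, 0.2]
mix_probs([p1, p2])  # [0.6, 0.25, 0.15]
```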

cscherrer commented 4 years ago

Think I'm getting close...

Say you start with `p = predictor_joint` from `example-linear-regression.jl`.

Then we can mess with the predictive distribution

julia> p.pred
@model (X, σ, β) begin
        η = X * β
        μ = η
        y ~ For(identity, Normal.(μ, σ))
    end

to get

julia> newpred = Soss.before(Soss.withdistributions(p.pred), p.model.response; strict=true, inclusive=false)
@model (X, σ, β) begin
        η = X * β
        μ = η
        _y_dist = For(identity, Normal.(μ, σ))
    end

Then, with a little `marginals` function like

# Collect the component distributions of a `For` by applying its kernel
# function to each index.
function marginals(d::For)
    return d.f.(d.θ)
end

we can get

julia> mar = marginals(rand(newpred(merge(p.args, particles(p.post))))._y_dist);

julia> typeof(mar)
Array{Normal{Particles{Float64,1000}},1}

julia> mar[1]
Normal{Particles{Float64,1000}}(
μ: -0.229 ± 0.0094
σ: 0.142 ± 0.0032
)

This is not quite what we want, but seems very close. And (as always with particles) I really like how clean and easy-to-read the representation is. Maybe we need something like

struct ParticleMixture{D,X} <: Distribution
    f :: D # the constructor, e.g. `Normal`
    pars :: X # the parameters, as particles
end

So this would have the same data as `f(pars...)`, but would allow us to write proper `rand` and `logpdf` methods. Hmm... actually this would be more natural as part of MonteCarloMeasurements. I'll think a bit more and then start an issue there for it.
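One possible semantics, sketched with plain Distributions.jl (the type name, the storage as a vector of parameter tuples, and both method bodies are illustrative guesses at an equal-weight particle mixture, not an existing API):

```julia
using Distributions, Random

# Illustrative stand-in: a distribution constructor plus one parameter
# tuple per particle.
struct ParticleMixtureSketch{D,X}
    f::D     # distribution constructor, e.g. `Normal`
    pars::X  # vector of parameter tuples, one per particle
end

# Sampling: pick a particle uniformly, then draw from its component.
Base.rand(rng::Random.AbstractRNG, m::ParticleMixtureSketch) =
    rand(rng, m.f(rand(rng, m.pars)...))

# Density: components have equal weight, so average their densities
# (computed as a log-mean-exp for numerical stability).
function Distributions.logpdf(m::ParticleMixtureSketch, x)
    lps = [logpdf(m.f(p...), x) for p in m.pars]
    mx = maximum(lps)
    return mx + log(sum(exp.(lps .- mx)) / length(lps))
end
```

With identical particles this degenerates to the component itself, e.g. `logpdf(ParticleMixtureSketch(Normal, [(0.0, 1.0), (0.0, 1.0)]), 0.3)` matches `logpdf(Normal(0.0, 1.0), 0.3)`.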

cscherrer commented 4 years ago

Just remembered about https://github.com/baggepinnen/MonteCarloMeasurements.jl/issues/22

Lots of great background there, need to reread it myself :)

@DilumAluthge let's go ahead with the release and update as this moves ahead.

DilumAluthge commented 4 years ago

I have created the Prediction project to keep track of progress on this issue.

DilumAluthge commented 4 years ago

I'm going to mark this as potentially breaking, since it will probably require some changes to the return types of public functions.