cscherrer / SossMLJ.jl

SossMLJ makes it easy to build MLJ machines from user-defined models from the Soss probabilistic programming language
https://cscherrer.github.io/SossMLJ.jl/stable/
MIT License

Implement a simple loss function #60

Closed cscherrer closed 3 years ago

cscherrer commented 3 years ago

> Tangentially, would this really be the way to use it? I'd think expected loss is a lot more useful than loss at expected value.

I 100% agree. I only did this PR so I could quickly get cross-validation up and running. See e.g.:

https://github.com/cscherrer/SossMLJ.jl/blob/cbfcd9be7792819f0a53d464da01b064b08803af/examples/example-bayesian-linear-regression.jl#L89-L89

But I agree that the better way to do this is to take the entire joint posterior predictive distribution and push it through the loss function, giving us a vector of losses that is an empirical distribution over the loss. Then we can e.g. take the mean or median of that vector of losses.

Unfortunately, I do not know how to implement this. @cscherrer Could you, as an example, implement this for a simple loss function? Then, I can try to implement it for another loss function.

The interface I am imagining is as follows. We first implement a method that looks like this:

function LOSSNAME_dist(sp::SossMLJPredictor{M}, y_true) where {M}
    # do the computation
    # ... 
    # ...
    return vector_of_losses # vector_of_losses is of type `Vector{Float64}`
end

And then also these two convenience methods:

function LOSSNAME_mean(sp::SossMLJPredictor{M}, y_true) where {M}
    vector_of_losses = LOSSNAME_dist(sp, y_true)
    return Statistics.mean(vector_of_losses)
end

function LOSSNAME_median(sp::SossMLJPredictor{M}, y_true) where {M}
    vector_of_losses = LOSSNAME_dist(sp, y_true)
    return Statistics.median(vector_of_losses)
end
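As a concrete sketch, here is what this could look like for root-mean-squared error, assuming the posterior predictive draws can be extracted as a matrix with one row per posterior sample and one column per observation. The names `rmse_dist`, `yhat_draws`, etc. are hypothetical; a real implementation would pull the draws out of the `SossMLJPredictor` rather than take a matrix:

```julia
using Statistics

# Hypothetical sketch: `yhat_draws` stands in for the posterior predictive
# draws (one row per posterior sample, one column per observation).
function rmse_dist(yhat_draws::AbstractMatrix, y_true::AbstractVector)
    # Compute the RMSE of each posterior draw against the observed data,
    # yielding an empirical distribution over the loss.
    return [sqrt(mean((row .- y_true) .^ 2)) for row in eachrow(yhat_draws)]
end

rmse_mean(yhat_draws, y_true)   = mean(rmse_dist(yhat_draws, y_true))
rmse_median(yhat_draws, y_true) = median(rmse_dist(yhat_draws, y_true))
```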

So, my workflow looks something like this:

mach = machine(model, X, y)
fit!(mach)
predictor_joint = MLJ.predict_joint(mach, X)
LOSSNAME_dist(predictor_joint, y)    # this returns a `Vector{Float64}`
LOSSNAME_mean(predictor_joint, y)    # this returns a `Float64`
LOSSNAME_median(predictor_joint, y)  # this returns a `Float64`

Originally posted by @DilumAluthge in #55 (comment)

cscherrer commented 3 years ago

I think we can use the particles trick from earlier to help here. Evaluating the loss on particles will give us new particles, and we can apply mean etc. to that.

DilumAluthge commented 3 years ago

> I think we can use the particles trick from earlier to help here. Evaluating the loss on particles will give us new particles, and we can apply mean etc. to that.

Sounds good to me.

cscherrer commented 3 years ago

This turns out to be really easy:

julia> rms( predict_particles(predictor_joint, X).yhat, y)
Particles{Float64,1000}
 0.956043 ± 0.0122

julia> rms( predict_particles(predictor_joint, X).yhat, y) |> mean
0.956042809853842

Guess that means we're getting some good machinery in place :)

Personally, I'd prefer the former. Do we need to return a Float64 to be compatible with other aspects of MLJ?

cscherrer commented 3 years ago

Oh we could have

rms(predictor_joint, X, y) = mean(rms( predict_particles(predictor_joint, X).yhat, y))

or something

DilumAluthge commented 3 years ago

I think we'll need both.

Like you said, the correct way to do it is the former, because we get back the uncertainty of the loss.

But for e.g. cross-validation, I think that MLJ will expect a float.

cscherrer commented 3 years ago

Would it be valid to have a kwarg for this? We could default to float if we need to.
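Concretely, I'm imagining something like this sketch, where `loss_draws` stands in for the per-posterior-sample losses (all names here are hypothetical, not actual SossMLJ API):

```julia
using Statistics

# Hypothetical kwarg-based loss: return a point estimate by default (what
# e.g. cross-validation would want), or the full empirical distribution of
# losses when asked.
function lossname(loss_draws::AbstractVector{<:Real}; summarize::Bool = true)
    return summarize ? mean(loss_draws) : loss_draws
end
```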

DilumAluthge commented 3 years ago

> Would it be valid to have a kwarg for this? We could default to float if we need to.

For CV, you don't actually call `rms` directly. You pass `rms` (which is an instance of a callable struct) to MLJ, and MLJ will call `rms`.

DilumAluthge commented 3 years ago

There is another issue to discuss.

Things like `rms` live in MLJBase. So we are overloading methods that live in MLJBase, which means we need to do one of the following three things:

  1. Keep `MLJBase` as a direct dependency of `SossMLJ`.
  2. Use `Requires` to make `MLJBase` a conditional dependency of `SossMLJ`.
  3. Make a separate package `SossMLJMeasures` that contains the functionality that requires `MLJBase`.

I'm not a fan of number 1. MLJBase is a pretty heavy dependency. It would be ideal if MLJModelInterface were the only MLJ-related dependency that SossMLJ has.

Number 2 is also not ideal. Requires has a lot of problems. In particular, you can't provide any compatibility information. We are hopefully going to eventually have first-class support for conditional dependencies in Pkg and Base Julia, but that won't happen before Julia 1.6, which is a long way away.

So I think we may be stuck with number 3.
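For completeness, option 2 would look roughly like this in SossMLJ's top-level module. This is only a sketch of why it's awkward: the `@require` line carries a UUID but no compat bounds can be attached to it. (The UUID shown should be checked against MLJBase's Project.toml, and `measures.jl` is a hypothetical file.)

```julia
module SossMLJ

using Requires

function __init__()
    # Load the MLJBase-dependent measure overloads only if the user loads
    # MLJBase. No `[compat]` entry can be attached to this dependency.
    @require MLJBase="a7f614a8-145f-11e9-1d2a-a57a1082229d" include("measures.jl")
end

end # module
```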

DilumAluthge commented 3 years ago

> Would it be valid to have a kwarg for this? We could default to float if we need to.

> For CV, you don't actually call `rms` directly. You pass `rms` (which is an instance of a callable struct) to MLJ, and MLJ will call `rms`.

To elaborate, here is one way to do CV for our Bayesian linear regression model:

MLJ.evaluate!(mach, resampling=MLJ.CV(; shuffle = true), measure=rms, operation=predict_mean)

Here, MLJBase has defined `const rms = RMS()`.

`RMS` is a callable struct.
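For anyone unfamiliar with the pattern, here is a minimal illustration of a callable struct (the names are made up; MLJBase's actual `RMS` is defined differently):

```julia
# A struct with no fields...
struct RMSMeasure end

# ...whose instances are made callable by defining a method on them:
function (m::RMSMeasure)(yhat, y)
    sqrt(sum(abs2, yhat .- y) / length(y))
end

# `my_rms` is an instance; it can be passed around and called like a
# function, which is exactly how MLJ invokes a measure during evaluation.
const my_rms = RMSMeasure()
```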

DilumAluthge commented 3 years ago

julia> MLJBase.rms == MLJBase.RMS()
true

DilumAluthge commented 3 years ago

> There is another issue to discuss.
>
> Things like `rms` live in MLJBase. So we are overloading methods that live in MLJBase, which means we need to do one of the following three things:
>
> 1. Keep `MLJBase` as a direct dependency of `SossMLJ`.
> 2. Use `Requires` to make `MLJBase` a conditional dependency of `SossMLJ`.
> 3. Make a separate package `SossMLJMeasures` that contains the functionality that requires `MLJBase`.
>
> I'm not a fan of number 1. MLJBase is a pretty heavy dependency. It would be ideal if MLJModelInterface were the only MLJ-related dependency that SossMLJ has.
>
> Number 2 is also not ideal. Requires has a lot of problems. In particular, you can't provide any compatibility information. We are hopefully going to eventually have first-class support for conditional dependencies in Pkg and Base Julia, but that won't happen before Julia 1.6, which is a long way away.
>
> So I think we may be stuck with number 3.

Actually... it's not just loss functions. For forwarding stuff through machines, we also need MLJBase.

I think we need to do number 1: keep all the functionality in one place in SossMLJ.jl. And yes, we have to keep a dependency on MLJBase. It is a heavy dependency, but we need the functionality.