I think we can use the `particles` trick from earlier to help here. Evaluating the loss on particles will give us new particles, and we can apply `mean` etc. to that.
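A minimal sketch of the idea, assuming the posterior predictions are stored as a plain matrix with one row per posterior draw. This is a stand-in for a `Particles` object, and `rms_single` and `rms_per_draw` are hypothetical names, not SossMLJ's API. Evaluating the loss once per draw gives a distribution of losses, which can then be summarized with `mean`, `std`, etc.

```julia
using Statistics

# Root-mean-square error for a single vector of predictions.
rms_single(yhat::AbstractVector, y::AbstractVector) = sqrt(mean(abs2, yhat .- y))

# Hypothetical helper: `draws` is an (ndraws × nobs) matrix where each row is
# one posterior draw of the predictions. Evaluating the loss on every draw
# gives back one loss value per draw, i.e. "new particles".
rms_per_draw(draws::AbstractMatrix, y::AbstractVector) =
    [rms_single(view(draws, i, :), y) for i in axes(draws, 1)]

y = [1.0, 2.0, 3.0]
draws = [1.0 2.0 3.0;   # draw 1: perfect fit
         2.0 3.0 4.0]   # draw 2: off by one everywhere

losses = rms_per_draw(draws, y)   # [0.0, 1.0]
mean(losses), std(losses)         # summarize the loss distribution
```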
Sounds good to me.
This turns out to be really easy:

```julia
julia> rms(predict_particles(predictor_joint, X).yhat, y)
Particles{Float64,1000}
 0.956043 ± 0.0122

julia> rms(predict_particles(predictor_joint, X).yhat, y) |> mean
0.956042809853842
```
Guess that means we're getting some good machinery in place :)
Personally, I'd prefer the former. Do we need to return a `Float64` to be compatible with other aspects of MLJ?
Oh, we could have

```julia
rms(predictor_joint, X, y) = mean(rms(predict_particles(predictor_joint, X).yhat, y))
```

or something.
I think we'll need both.
Like you said, the correct way to do it is the former, because we get back the uncertainty of the loss.
But for cross-validation, for example, I think MLJ will expect a float.
Would it be valid to have a kwarg for this? We could default to float if we need to.
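One way the keyword-argument idea could look, as a hedged sketch. `rms_draws` and its `summarize` keyword are hypothetical names, not an existing SossMLJ or MLJBase API, and the matrix-of-draws input is an assumed stand-in for particles. By default it collapses the per-draw losses to a single `Float64`. With `summarize = false` it returns every per-draw loss, so the uncertainty is preserved.

```julia
using Statistics

# Hypothetical loss: `draws` is an (ndraws × nobs) matrix of posterior
# predictions, one row per draw.
function rms_draws(draws::AbstractMatrix, y::AbstractVector; summarize::Bool = true)
    losses = [sqrt(mean(abs2, view(draws, i, :) .- y)) for i in axes(draws, 1)]
    # Default: return a plain Float64; otherwise keep the whole loss distribution.
    return summarize ? mean(losses) : losses
end

y = [1.0, 2.0, 3.0]
draws = [1.0 2.0 3.0;
         2.0 3.0 4.0]

rms_draws(draws, y)                     # 0.5
rms_draws(draws, y; summarize = false)  # [0.0, 1.0]
```

One design caveat: the return type here depends on a runtime keyword value, so the function is type-unstable. Two separate methods, one returning the distribution and one wrapping it in `mean`, would avoid that.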
For CV, you don't actually call `rms` directly. You pass `rms` (which is an instance of a callable struct) to MLJ, and MLJ will call `rms`.
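For readers less familiar with the pattern, here is a minimal sketch of a callable struct used as a measure. `MyRMS`, `myrms` and `evaluate_measure` are hypothetical stand-ins, not MLJBase's actual `RMS` implementation. The point is only that `rms` is a value you hand over, and the caller decides when to invoke it.

```julia
using Statistics

# A singleton struct whose instances can be called like functions.
struct MyRMS end
(::MyRMS)(yhat::AbstractVector, y::AbstractVector) = sqrt(mean(abs2, yhat .- y))

# Analogous to MLJBase's `const rms = RMS()`.
const myrms = MyRMS()

# Hypothetical stand-in for a framework routine that receives a measure
# and calls it internally; the user never calls `measure` directly.
evaluate_measure(measure, yhat, y) = measure(yhat, y)

evaluate_measure(myrms, [1.0, 2.0, 3.0], [2.0, 3.0, 4.0])  # 1.0
```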
There is another issue to discuss.

Things like `rms` live in `MLJBase`. So we are overloading methods that live in `MLJBase`, which means we need to do one of the following three things:

1. Keep `MLJBase` as a direct dependency of `SossMLJ`.
2. Use `Requires` to make `MLJBase` a conditional dependency of `SossMLJ`.
3. Make a separate package `SossMLJMeasures` that contains the functionality that requires `MLJBase`.

I'm not a fan of number 1. `MLJBase` is a pretty heavy dependency. It would be ideal if `MLJModelInterface` were the only MLJ-related dependency that `SossMLJ` has.

Number 2 is also not ideal. `Requires` has a lot of problems. In particular, you can't provide any compatibility information. We are hopefully going to eventually have first-class support for conditional dependencies in Pkg and Base Julia, but that won't happen before Julia 1.6, which is going to be a long way away.

So I think we may be stuck with number 3.
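This self-contained sketch shows why overloading forces the dependency, using two toy modules. `FakeBase` plays the role of `MLJBase` and owns the callable struct. `FakeSoss` plays the role of `SossMLJ` and adds a method for particle-style predictions. Both module names and the matrix-of-draws representation are illustrative assumptions. Because the new method is attached to `FakeBase.RMS`, `FakeSoss` cannot be defined unless `FakeBase` is loaded. That is the trade-off between the three options listed.

```julia
# Plays the role of MLJBase: it owns the measure type.
module FakeBase
using Statistics
struct RMS end
(::RMS)(yhat::AbstractVector, y::AbstractVector) = sqrt(mean(abs2, yhat .- y))
const rms = RMS()
end

# Plays the role of SossMLJ: it extends FakeBase's type, so it must
# be able to refer to FakeBase (i.e. FakeBase is a dependency).
module FakeSoss
import ..FakeBase

# New method: predictions given as an (ndraws × nobs) matrix of posterior draws.
# Returns one loss per draw, leaving the summary to the caller.
(m::FakeBase.RMS)(draws::AbstractMatrix, y::AbstractVector) =
    [m(view(draws, i, :), y) for i in axes(draws, 1)]
end

y = [1.0, 2.0, 3.0]
draws = [1.0 2.0 3.0; 2.0 3.0 4.0]

FakeBase.rms([2.0, 3.0, 4.0], y)  # 1.0 (method defined in FakeBase)
FakeBase.rms(draws, y)            # [0.0, 1.0] (method added by FakeSoss)
```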
To elaborate on the CV point, here is one way to do CV for our Bayesian linear regression model:

```julia
MLJ.evaluate!(mach, resampling=MLJ.CV(; shuffle = true), measure=rms, operation=predict_mean)
```

Here, `MLJBase` has defined `const rms = RMS()`. `RMS` is a callable struct:

```julia
julia> MLJBase.rms == MLJBase.RMS()
true
```
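As a rough illustration of why a resampling loop wants a scalar from the measure, here is a toy k-fold loop. It is not MLJ's implementation. `MyRMS`, `kfold_score`, the mean-only "model" and the unshuffled contiguous folds are assumptions made for brevity. The loop calls the measure once per fold and averages the results, which is straightforward when each call returns a `Float64`.

```julia
using Statistics

# A callable-struct RMS measure (hypothetical stand-in for MLJBase's RMS).
struct MyRMS end
(::MyRMS)(yhat::AbstractVector, y::AbstractVector) = sqrt(mean(abs2, yhat .- y))

# Hypothetical k-fold CV with contiguous, unshuffled folds.
function kfold_score(measure, y::AbstractVector; k::Int = 3)
    n = length(y)
    scores = Float64[]  # one scalar score per fold
    for test in Iterators.partition(1:n, cld(n, k))
        train = setdiff(1:n, test)
        # Toy "model": predict the mean of the training targets.
        yhat = fill(mean(y[train]), length(test))
        push!(scores, measure(yhat, y[test]))
    end
    return mean(scores)  # average the per-fold scores into one Float64
end

kfold_score(MyRMS(), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; k = 3)
```

A measure that returned per-draw losses instead of a `Float64` would fail at the `push!` into `scores`. This is one reason to keep a float-returning path alongside the particle-valued one.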
Actually... it's not just loss functions. For forwarding stuff through machines, we also need MLJBase.
I think we need to do number 1. Keep all the functionality in one place in SossMLJ.jl. And yes, we have to keep a dependency on MLJBase. It is a heavy dependency, but we need the functionality.
Originally posted by @DilumAluthge in #55 (comment)