mohamed82008 closed this 1 year ago.
Turns out I messed up the definition of the multi-target loss, so this PR is not necessary at all.
@mohamed82008 Thanks for spending time with this package and going to the trouble of contributing a PR.
I wonder if your confusion might be something others will also trip up on. If you can say a few words about what the confusion was, I will open an issue to address this in future doc fixes.
The background is that most loss functions don't support multiple targets by default. So when using multi-target regression from MLJFlux, I have to either use its built-in losses or roll my own implementation. I did the latter so I could try out more loss functions, e.g. from LossFunctions.jl. The problem was that the implementation of the `multi_target` function used in the tests in this PR was wrong, so it sometimes returned a vector instead of a scalar. Somehow this didn't break the optimisation, but it did break report generation. The following implementation made things work:
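```julia
function multi_target(loss)
    # Single (scalar) target: apply the wrapped loss directly.
    function _loss(x1::Real, x2::Real)
        return loss(x1, x2)
    end
    # Named tuple of target columns: sum the loss over every column.
    function _loss(x1, x2::NamedTuple)
        sum(map(x1, x2) do _x1, _x2
            sum(loss(_x1, _x2))
        end)
    end
    # Matrices: flatten both and sum the result of the loss.
    function _loss(x1::Matrix, x2::Matrix)
        sum(loss(vec(x1), vec(x2)))
    end
    return _loss
end
```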
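To check that a `multi_target`-style wrapper returns a single number for each input shape it handles, here is a small sketch. The squared-error loss `sqloss` is a hypothetical stand-in (any loss that broadcasts over vectors would do), and `multi_target` is repeated in compact form so the block runs on its own; it is not an MLJ or MLJFlux API.

```julia
# Compact copy of the `multi_target` wrapper from the comment above.
function multi_target(loss)
    _loss(x1::Real, x2::Real) = loss(x1, x2)
    _loss(x1, x2::NamedTuple) = sum(map((a, b) -> sum(loss(a, b)), x1, x2))
    _loss(x1::Matrix, x2::Matrix) = sum(loss(vec(x1), vec(x2)))
    return _loss
end

# Hypothetical single-target loss: squared error, broadcast over vectors.
sqloss(yhat, y) = (yhat .- y) .^ 2

mt = multi_target(sqloss)

# Scalars: (2 - 5)^2
mt(2.0, 5.0)                                  # 9.0
# Named tuple of target columns: (0 + 1) + (4 + 4)
mt((a = [1.0, 2.0], b = [3.0, 4.0]),
   (a = [1.0, 1.0], b = [1.0, 2.0]))          # 9.0
# Matrices are flattened before the loss is applied
mt([1.0 2.0; 3.0 4.0], [1.0 2.0; 3.0 5.0])    # 1.0
```

In every case the result is a single `Float64`, which is what downstream code that collects one measurement per evaluation would expect.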
Perhaps my life would have been easier if this vector-output behaviour had been detected and reported well before report generation, e.g. when building the history array. Also, the above function can be defined and documented to give people a way to change single-target loss functions to multi-target ones.
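A minimal sketch of the kind of early check suggested here, assuming the loss is called once per evaluation and its result is then stored. The names `check_scalar`, `checked` and `per_obs_loss` are hypothetical and not part of MLJ, MLJFlux or the package discussed in this PR; where such a check would live in the package's history-building code is also an assumption.

```julia
# Hypothetical helper (not a real package API): fail immediately if a loss
# returns something other than a single real number.
function check_scalar(m)
    m isa Real && return m
    throw(ArgumentError("loss returned a value of type $(typeof(m)); " *
                        "expected a single Real. Aggregate per-target " *
                        "losses (e.g. with sum) before returning."))
end

# Wrap any loss so that every call is checked as it is made.
checked(loss) = (yhat, y) -> check_scalar(loss(yhat, y))

# A loss that returns one value per observation instead of a scalar.
per_obs_loss(yhat, y) = (yhat .- y) .^ 2

# Aggregated version passes the check: (1 - 1)^2 + (2 - 0)^2
good = checked((yhat, y) -> sum(per_obs_loss(yhat, y)))
good([1.0, 2.0], [1.0, 0.0])    # 4.0

# Unaggregated version is rejected at the call site, not at report time.
try
    checked(per_obs_loss)([1.0, 2.0], [1.0, 0.0])
catch err
    println(err.msg)            # prints the explanatory error message
end
```

Raising the error where the loss is evaluated points straight at the faulty loss, instead of surfacing later as a type error while the report is assembled.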
Thanks for this explanation. It seems weird that the flawed custom loss did not break the optimization. I think that is worth investigating. I will open an issue.
> Also, the above function can be defined and documented to give people a way to change single-target loss functions to multi-target ones.
Yes, adding wrappers for multi-target losses is on the list: https://github.com/JuliaAI/MLJBase.jl/issues/502
Hi! Thanks for this really cool package. I noticed the multi-target MLJ example in this PR is not working because of an eltype error in the `measurements` vector, which triggers an error when the report is generated. I "fixed" it here in this PR and added a test, but I am not sure whether I was doing something wrong to begin with to trigger this error, as I am still a novice in MLJ. Please let me know if my example is wrong for some reason or if the fix here is appropriate.
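As a generic Julia illustration of how such an eltype error can arise (not the package's actual report code), a single vector-valued measurement widens the element type of the whole collection to `Any`, so any function restricted to real-valued vectors no longer applies. The helper `mean_measurement` is hypothetical.

```julia
# One vector-valued entry widens the element type of the whole collection.
scalar_measurements = [0.5, 0.3]            # all scalars
mixed_measurements  = [0.5, [0.1, 0.2]]     # one entry is a vector

println(eltype(scalar_measurements))  # Float64
println(eltype(mixed_measurements))   # Any

# Hypothetical report helper that only accepts real-valued measurements.
mean_measurement(v::AbstractVector{<:Real}) = sum(v) / length(v)

mean_measurement(scalar_measurements)     # works
try
    mean_measurement(mixed_measurements)  # no method for Vector{Any}
catch err
    println(typeof(err))                  # MethodError
end
```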